{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "27f443f9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sys\n",
    "sys.path.append('/data/run01/sczc619/LML/MetaTSNE')\n",
    "from bsne_utils import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "84312ae6",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Avoid reinventing the wheel — wheels need time and user testing to mature.\n",
    "# => Just adapt GraphTransformerLayer; wiring it up like a BertEncoder stack is enough.\n",
    "\n",
    "\n",
    "class GraphTransformerLayer(nn.Module):\n",
    "    \"\"\"Pre-norm Transformer encoder layer whose self-attention can be\n",
    "    restricted to graph neighbours via an adjacency mask.\n",
    "\n",
    "    Args:\n",
    "        d_model: embedding dimension of the node features.\n",
    "        nhead: number of attention heads.\n",
    "        dim_feedforward: hidden width of the position-wise FFN.\n",
    "        dropout: dropout rate used throughout the layer.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, d_model=256, nhead=8, dim_feedforward=768, dropout=0.1):\n",
    "        super().__init__()\n",
    "        self.linear1 = nn.Linear(d_model, dim_feedforward)\n",
    "        self.self_attn = nn.MultiheadAttention(\n",
    "            embed_dim=d_model,\n",
    "            num_heads=nhead,\n",
    "            dropout=dropout,\n",
    "            batch_first=True\n",
    "        )\n",
    "        \n",
    "        self.dropout = nn.Dropout(dropout)\n",
    "        self.linear2 = nn.Linear(dim_feedforward, d_model)\n",
    "        self.norm1 = nn.LayerNorm(d_model)\n",
    "        self.norm2 = nn.LayerNorm(d_model)\n",
    "        self.dropout1 = nn.Dropout(dropout)\n",
    "        self.dropout2 = nn.Dropout(dropout)\n",
    "        self.activation = nn.ReLU()\n",
    "\n",
    "    def forward(self, src, adj_mask=None):\n",
    "        \"\"\"Run one pre-norm attention + FFN block with residuals.\n",
    "\n",
    "        Args:\n",
    "            src: node feature tensor, attended over its sequence dim.\n",
    "            adj_mask: optional adjacency matrix; nonzero entries mark\n",
    "                allowed attention pairs. None means full attention.\n",
    "        \"\"\"\n",
    "        attn_mask = self._create_attention_mask(adj_mask)\n",
    "\n",
    "        # Pre-norm self-attention with residual connection.\n",
    "        src2 = self.norm1(src)\n",
    "        src2, attn_weights = self.self_attn(\n",
    "            src2, src2, src2,\n",
    "            attn_mask=attn_mask\n",
    "        )\n",
    "        src = src + self.dropout1(src2)\n",
    "\n",
    "        # Pre-norm feed-forward with residual connection.\n",
    "        src2 = self.norm2(src)\n",
    "        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))\n",
    "        src = src + self.dropout2(src2)\n",
    "        return src\n",
    "\n",
    "    def _create_attention_mask(self, adj_mask):\n",
    "        # True entries are *blocked* positions (bool-mask semantics of\n",
    "        # nn.MultiheadAttention), i.e. node pairs with no edge.\n",
    "        if adj_mask is None:\n",
    "            return None\n",
    "        mask = (adj_mask == 0).bool()\n",
    "        # One mask copy per attention head: shape (nhead, N, N).\n",
    "        # NOTE(review): a node with no neighbours produces an all-True row,\n",
    "        # which makes the attention softmax NaN — assumes every node has at\n",
    "        # least one edge (e.g. a self-loop); confirm upstream.\n",
    "        mask = mask.repeat(self.self_attn.num_heads, 1, 1)\n",
    "        return mask\n",
    "\n",
    "\n",
    "class multi_HOGRL_Transformer(nn.Module):\n",
    "    \"\"\"Multi-relation graph Transformer with gated tree-view fusion.\n",
    "\n",
    "    Per relation, the projected features go through a stack of\n",
    "    GraphTransformerLayers masked by that relation's adjacency. In\n",
    "    parallel, each tree view is projected by an MLP, passed once through\n",
    "    the stack's final layer masked by the tree adjacency, and the tree\n",
    "    outputs are fused with a learned softmax gate. Relation outputs are\n",
    "    concatenated for classification and averaged for the t-SNE head.\n",
    "\n",
    "    Args:\n",
    "        in_feat: input node feature dimension.\n",
    "        out_feat: number of output classes.\n",
    "        relation_nums: number of edge relations in the graph.\n",
    "        d_model: Transformer hidden dimension.\n",
    "        nhead: attention heads per layer.\n",
    "        num_layers: Transformer layers per relation.\n",
    "        dim_feedforward: FFN width inside each layer.\n",
    "        drop_rate: dropout rate.\n",
    "        layers_tree: number of tree views per relation.\n",
    "        tsne_weight: stored t-SNE loss weight (not used in forward).\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, in_feat, out_feat, relation_nums=3, d_model=256,\n",
    "                 nhead=8, num_layers=5, dim_feedforward=768,\n",
    "                 drop_rate=0.6, layers_tree=2, tsne_weight=0.1):\n",
    "        super().__init__()\n",
    "        self.relation_nums = relation_nums\n",
    "        self.d_model = d_model\n",
    "\n",
    "        self.feature_proj = nn.Linear(in_feat, d_model)\n",
    "\n",
    "        # One stack of masked Transformer layers per relation.\n",
    "        self.transformer_layers = nn.ModuleList([\n",
    "            nn.ModuleList([GraphTransformerLayer(\n",
    "                d_model=d_model,\n",
    "                nhead=nhead,\n",
    "                dim_feedforward=dim_feedforward,\n",
    "                dropout=drop_rate\n",
    "            ) for _ in range(num_layers)])\n",
    "            for _ in range(relation_nums)\n",
    "        ])\n",
    "\n",
    "        # Per-relation, per-tree-view MLP projections.\n",
    "        self.tree_projs = nn.ModuleList([\n",
    "            nn.ModuleList([nn.Sequential(\n",
    "                nn.Linear(d_model, dim_feedforward),\n",
    "                nn.ReLU(),\n",
    "                nn.Linear(dim_feedforward, d_model)\n",
    "            ) for _ in range(layers_tree)])\n",
    "            for _ in range(relation_nums)\n",
    "        ])\n",
    "\n",
    "        # Scalar gate per tree view, softmax-normalized across views.\n",
    "        self.gating_networks = nn.ModuleList([\n",
    "            nn.ModuleList([nn.Linear(d_model, 1)\n",
    "                           for _ in range(layers_tree)])\n",
    "            for _ in range(relation_nums)\n",
    "        ])\n",
    "\n",
    "        # Classifier over the concatenation of all relation outputs.\n",
    "        self.classifier = nn.Sequential(\n",
    "            nn.Linear(relation_nums * d_model, 512),\n",
    "            nn.ReLU(),\n",
    "            nn.Dropout(drop_rate),\n",
    "            nn.Linear(512, out_feat)\n",
    "        )\n",
    "\n",
    "        self.tsne_weight = tsne_weight\n",
    "\n",
    "    def forward(self, x, edge_indexs, sub_nodes=None):\n",
    "        \"\"\"Return (log-softmax logits, t-SNE embedding features).\n",
    "\n",
    "        Args:\n",
    "            x: node feature matrix.\n",
    "            edge_indexs: per-relation pairs (edge_index, tree_edge_indices).\n",
    "            sub_nodes: optional node indices to slice x before encoding.\n",
    "        \"\"\"\n",
    "        if sub_nodes is not None:\n",
    "            x = x[sub_nodes]\n",
    "\n",
    "        x = self.feature_proj(x)\n",
    "\n",
    "        relation_outputs = []\n",
    "        for rel_idx in range(self.relation_nums):\n",
    "            edge_index = edge_indexs[rel_idx][0]\n",
    "            # Dense NxN adjacency used as the attention mask.\n",
    "            adj_matrix = to_dense_adj(edge_index, max_num_nodes=x.size(0))[0]\n",
    "            tree_indices = edge_indexs[rel_idx][1]\n",
    "\n",
    "            # Main branch: full Transformer stack over the relation graph.\n",
    "            h = x\n",
    "            for layer in self.transformer_layers[rel_idx]:\n",
    "                h = layer(h, adj_matrix.bool())\n",
    "\n",
    "            # Tree branch: MLP projection per view, then one pass through the\n",
    "            # *last* layer of this relation's main stack (weights shared).\n",
    "            tree_features = []\n",
    "            for tree_idx, tree_edges in enumerate(tree_indices):\n",
    "                tree_adj = to_dense_adj(tree_edges, max_num_nodes=x.size(0))[0]\n",
    "                h_tree = x\n",
    "                for layer in self.tree_projs[rel_idx][tree_idx]:\n",
    "                    h_tree = layer(h_tree)\n",
    "                h_tree = self.transformer_layers[rel_idx][-1](h_tree, tree_adj.bool())\n",
    "                tree_features.append(h_tree)\n",
    "\n",
    "            # Softmax gate over tree views, then weighted sum of the views.\n",
    "            gates = torch.stack([\n",
    "                self.gating_networks[rel_idx][i](feat)\n",
    "                for i, feat in enumerate(tree_features)\n",
    "            ], dim=-1)\n",
    "            alpha = F.softmax(gates, dim=-1)\n",
    "            fused_tree = sum([feat * alpha[:, :, i]\n",
    "                              for i, feat in enumerate(tree_features)])\n",
    "\n",
    "            relation_output = h + fused_tree\n",
    "            relation_outputs.append(relation_output)\n",
    "\n",
    "        combined = torch.cat(relation_outputs, dim=-1)\n",
    "\n",
    "        logits = self.classifier(combined)\n",
    "        logits = F.log_softmax(logits, dim=-1)\n",
    "\n",
    "        # Mean over relations serves as the t-SNE embedding.\n",
    "        tsne_feats = torch.stack(relation_outputs, dim=1).mean(dim=1)\n",
    "\n",
    "        return logits, tsne_feats"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "4e5d5c83",
   "metadata": {},
   "outputs": [],
   "source": [
    "def test(idx_eval, y_eval, model, feat_data, edge_indexs, device):\n",
    "    \"\"\"Evaluate `model` on the nodes listed in `idx_eval`.\n",
    "\n",
    "    Runs a full-graph forward pass, extracts each evaluation node's\n",
    "    positive-class probability, and returns (AUC, AP, macro-F1, G-mean).\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    with torch.no_grad():\n",
    "        log_probs, _ = model(feat_data.to(device), edge_indexs, sub_nodes=None)\n",
    "        probs = torch.exp(log_probs).cpu().detach()\n",
    "        pos_probs = probs[:, 1].numpy()[np.array(idx_eval)]\n",
    "        y_true = np.array(y_eval)\n",
    "        auc = roc_auc_score(y_true, np.array(pos_probs))\n",
    "        ap = average_precision_score(y_true, np.array(pos_probs))\n",
    "        hard_preds = (np.array(pos_probs) >= 0.5).astype(int)\n",
    "        macro_f1 = f1_score(y_true, hard_preds, average='macro')\n",
    "        gmean = calculate_g_mean(y_true, hard_preds)\n",
    "\n",
    "    return auc, ap, macro_f1, gmean\n",
    "\n",
    "\n",
    "def calculate_tsne_loss(emb_p, emb_u, dist_sub_p, dist_matrix, batch_p_global, batch_u_global, temperature=1,\n",
    "                        eps=1e-12):\n",
    "    \"\"\"BSNE pretraining loss: local t-SNE-style term on the B_p subgraph\n",
    "    plus a global density-ratio term against the uniform sample B_u.\n",
    "\n",
    "    Args:\n",
    "        emb_p: embeddings of the B_p subgraph nodes.\n",
    "        emb_u: embeddings of the uniformly sampled B_u nodes.\n",
    "        dist_sub_p: pairwise graph distances restricted to B_p.\n",
    "        dist_matrix: full-graph pairwise distance matrix.\n",
    "        batch_p_global: global node indices of B_p.\n",
    "        batch_u_global: global node indices of B_u (currently unused).\n",
    "        temperature: divisor for squared embedding distances in the\n",
    "            Student-t kernels.\n",
    "        eps: numerical floor against log(0) / division by zero.\n",
    "    \"\"\"\n",
    "\n",
    "    device = emb_p.device  # NOTE(review): assigned but never used below\n",
    "    batch_size = emb_p.size(0)\n",
    "\n",
    "    # --- Local term: based on the B_p subgraph ---\n",
    "    # Map global node ids to local subgraph positions.\n",
    "    subnode_to_local = {node: i for i, node in enumerate(batch_p_global)}\n",
    "    # NOTE(review): built from batch_p_global itself, so local_indices is the\n",
    "    # identity permutation whenever the sampled nodes are unique — confirm\n",
    "    # the sampler returns unique nodes; if so this indexing is a no-op.\n",
    "    local_indices = [subnode_to_local[node] for node in batch_p_global]\n",
    "\n",
    "    # Extract the local distance matrix.\n",
    "    dist_p = dist_sub_p[local_indices][:, local_indices]\n",
    "\n",
    "    # Compute P: Student-t affinities from graph distances, row-normalized.\n",
    "    # P = torch.exp(-dist_p ** 2)\n",
    "    P = (1.0 + dist_p ** 2) ** -1\n",
    "    P.fill_diagonal_(0)\n",
    "    P = (P + P.T) / 2  # symmetrize\n",
    "    P = P / (P.sum(dim=1, keepdim=True) + eps)\n",
    "    # P = P / P.sum()\n",
    "    P = torch.clamp(P, min=eps)\n",
    "\n",
    "    # Compute Q: Student-t affinities from embedding distances.\n",
    "    # pairwise_dist = torch.cdist(emb_p, emb_p)\n",
    "    pairwise_dist = torch.cdist(emb_p, emb_p, p=2)\n",
    "    Q = (1.0 + pairwise_dist ** 2 / temperature) ** -1\n",
    "    Q.fill_diagonal_(0)\n",
    "    Q = (Q + Q.T) / 2\n",
    "    Q = Q / (Q.sum(dim=1, keepdim=True) + eps)\n",
    "    # Q = Q / Q.sum()\n",
    "    Q = torch.clamp(Q, min=eps)\n",
    "\n",
    "\n",
    "    # Local loss. NOTE(review): labelled \\\"KL divergence\\\" in the original,\n",
    "    # but this is mean(log P - log Q), not the P-weighted sum\n",
    "    # sum(P * (log P - log Q)) — confirm this is intended.\n",
    "    loss_local = (torch.log(P) - torch.log(Q)).mean()\n",
    "\n",
    "    # --- Global term ---\n",
    "    # Squared Euclidean distances from emb_p to emb_u.\n",
    "    dist_pu_sq = torch.cdist(emb_p, emb_u, p=2) ** 2\n",
    "    d_bu = (1.0 + dist_pu_sq / temperature) ** -1\n",
    "    d_bu = d_bu.sum(dim=1)\n",
    "\n",
    "    pairwise_dist_sq = pairwise_dist ** 2\n",
    "    d_bp = (1.0 + pairwise_dist_sq / temperature) ** -1\n",
    "    d_bp = d_bp.sum(dim=1) + eps\n",
    "\n",
    "    # Compute k_Bp: per-node affinity mass inside B_p relative to the whole\n",
    "    # graph, rescaled by the graph-size / batch-size ratio.\n",
    "    p_xi_full = (1.0 + dist_matrix ** 2) ** -1 \n",
    "    sum_p_xi = p_xi_full[batch_p_global][:, batch_p_global].sum(dim=1)\n",
    "    k_Bp = (sum_p_xi / p_xi_full[batch_p_global].sum(dim=1)) * (dist_matrix.shape[0] / batch_size)\n",
    "\n",
    "    ratio = (k_Bp.unsqueeze(1) * d_bu) / d_bp.unsqueeze(1)\n",
    "    loss_global = torch.log(ratio.clamp(min=eps)).mean()\n",
    "\n",
    "#     print(len(k_Bp),f\"k_Bp: {k_Bp}\")\n",
    "#     print(f\"d_bu mean: {d_bu.mean().item()}, d_bu max: {d_bu.max().item()}, d_bu min: {d_bu.min().item()}\")\n",
    "#     print(f\"d_bp mean: {d_bp.mean().item()}, d_bp max: {d_bp.max().item()}, d_bp min: {d_bp.min().item()}\")\n",
    "\n",
    "\n",
    "#     print(f\"P mean: {P.mean().item()}, P max: {P.max().item()}, P min: {P.min().item()}\")\n",
    "#     print(f\"Q mean: {Q.mean().item()}, Q max: {Q.max().item()}, Q min: {Q.min().item()}\")\n",
    "    print(f\"local_loss: {loss_local.item()},global_loss: {loss_global.item()}\")\n",
    "\n",
    "    return loss_local + loss_global"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7b18a3b5",
   "metadata": {
    "scrolled": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "loading data...\n",
      "\n",
      "=== Starting Pretraining ===\n",
      "local_loss: 2.6089541912078857,global_loss: -0.03623811528086662\n",
      "Pretrain Epoch: 000, TSNE Loss: 0.7718\n",
      "local_loss: 2.6296911239624023,global_loss: -0.03309698775410652\n",
      "local_loss: 2.592395067214966,global_loss: -0.035128217190504074\n",
      "local_loss: 2.497469663619995,global_loss: -0.04721358045935631\n",
      "local_loss: 2.4071149826049805,global_loss: -0.051689885556697845\n",
      "local_loss: 2.342468738555908,global_loss: -0.05975615233182907\n",
      "Pretrain Epoch: 005, TSNE Loss: 0.6848\n",
      "local_loss: 2.1912763118743896,global_loss: -0.06866522878408432\n",
      "local_loss: 2.157341480255127,global_loss: -0.08031507581472397\n",
      "local_loss: 2.066950559616089,global_loss: -0.08595868200063705\n",
      "local_loss: 1.9628474712371826,global_loss: -0.0925396978855133\n",
      "local_loss: 1.9139338731765747,global_loss: -0.11043491214513779\n",
      "Pretrain Epoch: 010, TSNE Loss: 0.5410\n",
      "local_loss: 1.7886779308319092,global_loss: -0.1029379591345787\n",
      "local_loss: 1.6421738862991333,global_loss: -0.11174364387989044\n",
      "local_loss: 1.5797291994094849,global_loss: -0.11739421635866165\n",
      "local_loss: 1.4547098875045776,global_loss: -0.12491020560264587\n",
      "local_loss: 1.3633981943130493,global_loss: -0.1403031051158905\n",
      "Pretrain Epoch: 015, TSNE Loss: 0.3669\n",
      "local_loss: 1.2556706666946411,global_loss: -0.17264920473098755\n",
      "local_loss: 1.2637783288955688,global_loss: -0.1907101571559906\n",
      "local_loss: 1.1778030395507812,global_loss: -0.1986800730228424\n",
      "local_loss: 1.0514488220214844,global_loss: -0.21266043186187744\n",
      "local_loss: 0.8888944387435913,global_loss: -0.2144862860441208\n",
      "Pretrain Epoch: 020, TSNE Loss: 0.2023\n",
      "local_loss: 0.7560893893241882,global_loss: -0.2010185718536377\n",
      "local_loss: 0.6116855144500732,global_loss: -0.1927635669708252\n",
      "local_loss: 0.49238044023513794,global_loss: -0.1882706731557846\n",
      "local_loss: 0.4056878983974457,global_loss: -0.18891219794750214\n",
      "local_loss: 0.29519763588905334,global_loss: -0.1839122623205185\n",
      "Pretrain Epoch: 025, TSNE Loss: 0.0334\n",
      "local_loss: 0.2587205171585083,global_loss: -0.20770764350891113\n",
      "local_loss: 0.15794934332370758,global_loss: -0.18359145522117615\n",
      "local_loss: 0.12573345005512238,global_loss: -0.19806931912899017\n",
      "local_loss: 0.07561244815587997,global_loss: -0.19298997521400452\n",
      "local_loss: 0.05120706930756569,global_loss: -0.19678950309753418\n",
      "Pretrain Epoch: 030, TSNE Loss: -0.0437\n",
      "local_loss: 0.03227933496236801,global_loss: -0.18085254728794098\n",
      "local_loss: 0.024579185992479324,global_loss: -0.18697352707386017\n",
      "local_loss: 0.01534364651888609,global_loss: -0.19486567378044128\n",
      "local_loss: 0.010734467767179012,global_loss: -0.19474409520626068\n",
      "local_loss: 0.009387802332639694,global_loss: -0.19445709884166718\n",
      "Pretrain Epoch: 035, TSNE Loss: -0.0555\n",
      "local_loss: 0.005861146841198206,global_loss: -0.19899556040763855\n",
      "local_loss: 0.001441270811483264,global_loss: -0.19080302119255066\n",
      "local_loss: -0.004558151122182608,global_loss: -0.1810920089483261\n",
      "local_loss: -0.010738498531281948,global_loss: -0.17269670963287354\n",
      "local_loss: -0.030052311718463898,global_loss: -0.13592787086963654\n",
      "Pretrain Epoch: 040, TSNE Loss: -0.0498\n",
      "local_loss: -0.03553919494152069,global_loss: -0.11442219465970993\n",
      "local_loss: -0.036999769508838654,global_loss: -0.10196486115455627\n",
      "local_loss: -0.03836686909198761,global_loss: -0.09397732466459274\n",
      "local_loss: -0.037519436329603195,global_loss: -0.07487709075212479\n",
      "local_loss: -0.03636252507567406,global_loss: -0.05281583219766617\n",
      "Pretrain Epoch: 045, TSNE Loss: -0.0268\n",
      "local_loss: -0.037083860486745834,global_loss: -0.03736116364598274\n",
      "local_loss: -0.035679880529642105,global_loss: -0.02693142741918564\n",
      "local_loss: -0.03657175600528717,global_loss: -0.0024685736279934645\n",
      "local_loss: -0.03032858297228813,global_loss: -0.005631816107779741\n",
      "local_loss: -0.03470884636044502,global_loss: 0.032367002218961716\n",
      "Pretrain Epoch: 050, TSNE Loss: -0.0007\n",
      "local_loss: -0.03234309330582619,global_loss: 0.03553229197859764\n",
      "local_loss: -0.026793263852596283,global_loss: 0.04097168520092964\n",
      "local_loss: -0.025374604389071465,global_loss: 0.05160922184586525\n",
      "local_loss: -0.02293187379837036,global_loss: 0.062259241938591\n",
      "local_loss: -0.017543626949191093,global_loss: 0.07119829207658768\n",
      "Pretrain Epoch: 055, TSNE Loss: 0.0161\n",
      "local_loss: -0.017773106694221497,global_loss: 0.0916919931769371\n",
      "Pretrain early stopping at epoch 56\n",
      "\n",
      "=== Starting Fine-tuning ===\n",
      "Epoch: 000 | Loss: 3.7957 | Val AUC: 0.8691 | Val F1: 0.7924\n",
      "Epoch: 005 | Loss: 0.0958 | Val AUC: 0.8873 | Val F1: 0.8529\n",
      "Epoch: 010 | Loss: 0.0852 | Val AUC: 0.9001 | Val F1: 0.9047\n",
      "Epoch: 015 | Loss: 0.0498 | Val AUC: 0.9088 | Val F1: 0.9165\n",
      "Epoch: 020 | Loss: 0.0514 | Val AUC: 0.8941 | Val F1: 0.9021\n",
      "Epoch: 025 | Loss: 0.0735 | Val AUC: 0.9214 | Val F1: 0.9118\n",
      "Epoch: 030 | Loss: 0.0680 | Val AUC: 0.9180 | Val F1: 0.9065\n",
      "Epoch: 035 | Loss: 0.1799 | Val AUC: 0.9065 | Val F1: 0.9015\n",
      "Epoch: 040 | Loss: 0.0527 | Val AUC: 0.9207 | Val F1: 0.9161\n",
      "Epoch: 045 | Loss: 0.1229 | Val AUC: 0.9026 | Val F1: 0.9103\n",
      "Epoch: 050 | Loss: 0.1376 | Val AUC: 0.9202 | Val F1: 0.9119\n",
      "Epoch: 055 | Loss: 0.1824 | Val AUC: 0.8575 | Val F1: 0.8690\n",
      "Epoch: 060 | Loss: 0.0566 | Val AUC: 0.9229 | Val F1: 0.9203\n",
      "Epoch: 065 | Loss: 0.1336 | Val AUC: 0.9197 | Val F1: 0.8953\n",
      "Epoch: 070 | Loss: 0.0393 | Val AUC: 0.9017 | Val F1: 0.8885\n",
      "Epoch: 075 | Loss: 0.2121 | Val AUC: 0.9044 | Val F1: 0.8835\n",
      "Epoch: 080 | Loss: 0.0722 | Val AUC: 0.9321 | Val F1: 0.9166\n",
      "Epoch: 085 | Loss: 0.0777 | Val AUC: 0.9353 | Val F1: 0.9085\n",
      "Epoch: 090 | Loss: 0.1690 | Val AUC: 0.9288 | Val F1: 0.9187\n",
      "Epoch: 095 | Loss: 0.0595 | Val AUC: 0.9138 | Val F1: 0.9081\n",
      "Epoch: 100 | Loss: 0.0883 | Val AUC: 0.9479 | Val F1: 0.9094\n",
      "Epoch: 105 | Loss: 0.0843 | Val AUC: 0.9555 | Val F1: 0.9218\n",
      "Epoch: 110 | Loss: 0.0800 | Val AUC: 0.9483 | Val F1: 0.9145\n",
      "Epoch: 115 | Loss: 0.1322 | Val AUC: 0.9425 | Val F1: 0.9124\n",
      "Epoch: 120 | Loss: 0.1083 | Val AUC: 0.9531 | Val F1: 0.9069\n",
      "Epoch: 125 | Loss: 0.0525 | Val AUC: 0.9113 | Val F1: 0.8739\n",
      "Epoch: 130 | Loss: 0.0657 | Val AUC: 0.9213 | Val F1: 0.8590\n",
      "Epoch: 135 | Loss: 0.1236 | Val AUC: 0.9407 | Val F1: 0.9009\n",
      "Epoch: 140 | Loss: 0.1579 | Val AUC: 0.8798 | Val F1: 0.8665\n",
      "Epoch: 145 | Loss: 0.1409 | Val AUC: 0.9353 | Val F1: 0.9041\n",
      "Epoch: 150 | Loss: 0.0925 | Val AUC: 0.9172 | Val F1: 0.9113\n",
      "Epoch: 155 | Loss: 0.0885 | Val AUC: 0.9485 | Val F1: 0.8990\n",
      "Epoch: 160 | Loss: 0.0809 | Val AUC: 0.9446 | Val F1: 0.9129\n",
      "Epoch 00033: reducing learning rate of group 0 to 2.5000e-04.\n",
      "Epoch: 165 | Loss: 0.0935 | Val AUC: 0.9256 | Val F1: 0.8801\n",
      "Epoch: 170 | Loss: 0.0652 | Val AUC: 0.9450 | Val F1: 0.9160\n",
      "Epoch: 175 | Loss: 0.0619 | Val AUC: 0.9282 | Val F1: 0.9075\n",
      "Epoch: 180 | Loss: 0.2701 | Val AUC: 0.9184 | Val F1: 0.9059\n",
      "Epoch: 185 | Loss: 0.1291 | Val AUC: 0.9373 | Val F1: 0.9069\n",
      "Epoch: 190 | Loss: 0.0422 | Val AUC: 0.9232 | Val F1: 0.9160\n",
      "Epoch: 195 | Loss: 0.1586 | Val AUC: 0.9353 | Val F1: 0.8799\n",
      "Epoch: 200 | Loss: 0.0552 | Val AUC: 0.9395 | Val F1: 0.9129\n",
      "Epoch: 205 | Loss: 0.1025 | Val AUC: 0.9331 | Val F1: 0.9075\n",
      "Epoch: 210 | Loss: 0.0911 | Val AUC: 0.9335 | Val F1: 0.9092\n",
      "Epoch: 215 | Loss: 0.0709 | Val AUC: 0.9241 | Val F1: 0.9047\n",
      "Epoch 00044: reducing learning rate of group 0 to 1.2500e-04.\n",
      "Epoch: 220 | Loss: 0.1099 | Val AUC: 0.9393 | Val F1: 0.9139\n",
      "Epoch: 225 | Loss: 0.0693 | Val AUC: 0.9311 | Val F1: 0.9036\n",
      "Epoch: 230 | Loss: 0.1245 | Val AUC: 0.9286 | Val F1: 0.8883\n",
      "Epoch: 235 | Loss: 0.0882 | Val AUC: 0.9372 | Val F1: 0.9024\n",
      "Epoch: 240 | Loss: 0.0792 | Val AUC: 0.9436 | Val F1: 0.8965\n",
      "Epoch: 245 | Loss: 0.0818 | Val AUC: 0.9485 | Val F1: 0.9072\n",
      "Epoch: 250 | Loss: 0.0648 | Val AUC: 0.9312 | Val F1: 0.9043\n",
      "Epoch: 255 | Loss: 0.1271 | Val AUC: 0.9382 | Val F1: 0.8788\n",
      "Early stopping at epoch 255\n",
      "\n",
      "=== Final Test Results ===\n",
      "Test AUC: 0.9592 | Test AP: 0.8693 | Test F1: 0.9143 | G-mean: 0.8814\n"
     ]
    }
   ],
   "source": [
    "# BSNE algorithm: t-SNE-style pretraining followed by supervised fine-tuning.\n",
    "def bsne_main(args):\n",
    "    \"\"\"Train and evaluate the multi-relation graph Transformer with BSNE.\n",
    "\n",
    "    Stage 1 pretrains the encoder (classifier frozen) with a t-SNE loss on\n",
    "    sampled subgraphs; stage 2 fine-tunes with NLL loss, early-stopping on\n",
    "    validation AUC; the best checkpoint is then tested.\n",
    "\n",
    "    Args:\n",
    "        args: dict of hyper-parameters (dataset, batch_size, lrs, patience, ...).\n",
    "    \"\"\"\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    # device = torch.device('cpu')\n",
    "    print(device)\n",
    "\n",
    "    timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n",
    "    writer = SummaryWriter(f'runs/{args[\"dataset\"]}_{timestamp}')\n",
    "\n",
    "    print('loading data...')\n",
    "    prefix = \"/data/run01/sczc619/LML/MetaTSNE/data/\"\n",
    "    edge_indexs, feat_data, labels = load_data(args['dataset'], args['layers_tree'], prefix)\n",
    "\n",
    "    np.random.seed(args['seed'])\n",
    "    rd.seed(args['seed'])\n",
    "\n",
    "    # Stratified train/val/test split; Amazon skips the first 3305 nodes.\n",
    "    if args['dataset'] == 'yelp':\n",
    "        index = list(range(len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels, stratify=labels,\n",
    "                                                                        test_size=args['test_size'], random_state=2,\n",
    "                                                                        shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"Yelp_shortest_distance.pkl\")\n",
    "    elif args['dataset'] == 'amazon':\n",
    "        index = list(range(3305, len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels[3305:],\n",
    "                                                                        stratify=labels[3305:],\n",
    "                                                                        test_size=args['test_size'],\n",
    "                                                                        random_state=2, shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"Amazon_shortest_distance.pkl\")\n",
    "\n",
    "    # NOTE(review): dist_path is only bound for 'yelp'/'amazon'; any other\n",
    "    # dataset name raises NameError here — consider an explicit ValueError.\n",
    "    with open(dist_path, 'rb') as f:\n",
    "        dist_data = pickle.load(f)\n",
    "        dist_matrix = torch.tensor(dist_data['dist_matrix']).to(device)\n",
    "\n",
    "\n",
    "    # NOTE(review): adj_dict is built here but never read later in this\n",
    "    # function — possibly dead code.\n",
    "    adj_dict = defaultdict(list)\n",
    "    for rel in edge_indexs:\n",
    "        edge_index = rel[0].cpu().numpy()\n",
    "        for src, dst in zip(edge_index[0], edge_index[1]):\n",
    "            adj_dict[src].append(dst)\n",
    "\n",
    "    gnn_model = multi_HOGRL_Transformer(\n",
    "        in_feat=feat_data.shape[1],\n",
    "        out_feat=2,\n",
    "        relation_nums=len(edge_indexs),\n",
    "        d_model=128,\n",
    "        nhead=args['num_heads'],\n",
    "        num_layers=3,\n",
    "        dim_feedforward=256,\n",
    "        drop_rate=args['drop_rate'],\n",
    "        layers_tree=args['layers_tree'],\n",
    "        tsne_weight=args['tsne_weight']\n",
    "    ).to(device)\n",
    "\n",
    "    # Move edge indices and tree-edge tensors to the compute device.\n",
    "    for edge_index in edge_indexs:\n",
    "        edge_index[0] = edge_index[0].to(device)\n",
    "        edge_index[1] = [tensor.to(device) for tensor in edge_index[1]]\n",
    "    feat_data = torch.tensor(feat_data).float().to(device)\n",
    "\n",
    "\n",
    "    print(\"\\n=== Starting Pretraining ===\")\n",
    "\n",
    "    # Freeze the classifier head; pretrain only the encoder.\n",
    "    gnn_model.classifier.requires_grad_(False)\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        filter(lambda p: p.requires_grad, gnn_model.parameters()),\n",
    "        lr=args['pretrain_lr'],\n",
    "        weight_decay=5e-5\n",
    "    )\n",
    "    pretrain_best_loss = float('inf')\n",
    "    pretrain_no_improve = 0\n",
    "    pretrain_early_stop = False\n",
    "\n",
    "    for epoch in range(args['pretrain_epochs']):\n",
    "        if pretrain_early_stop:\n",
    "            break\n",
    "\n",
    "        gnn_model.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # First sampling: subgraph B_p around random batch centers.\n",
    "        batch_centers = rd.sample(range(feat_data.shape[0]), args['batch_size'])\n",
    "        sub_nodes_p = sample_subgraph(batch_centers, dist_matrix, args['sample_size'])\n",
    "\n",
    "        # Second sampling: uniform global batch B_u, same size as B_p.\n",
    "        batch_u_global = np.random.choice(feat_data.shape[0], size=len(sub_nodes_p), replace=False)\n",
    "\n",
    "        # Embed the B_p subgraph (with gradients).\n",
    "        feat_sub_p = feat_data[sub_nodes_p]\n",
    "        _, embeddings_p = gnn_model(feat_sub_p, edge_indexs, sub_nodes=None)\n",
    "\n",
    "        # Embed B_u without gradients; used as reference points only.\n",
    "        feat_u = feat_data[batch_u_global]\n",
    "        with torch.no_grad():\n",
    "            _, embeddings_u = gnn_model(feat_u, edge_indexs, sub_nodes=None)\n",
    "\n",
    "        # Pairwise shortest-path distances restricted to the B_p subgraph.\n",
    "        dist_sub_p = dist_matrix[sub_nodes_p][:, sub_nodes_p]\n",
    "\n",
    "        # Compute the weighted t-SNE pretraining loss.\n",
    "        tsne_loss = calculate_tsne_loss(\n",
    "            embeddings_p,  # subgraph embeddings\n",
    "            embeddings_u,  # global sample embeddings\n",
    "            dist_sub_p,  # B_p subgraph distances\n",
    "            dist_matrix,  # full-graph distances\n",
    "            sub_nodes_p,  # B_p global indices\n",
    "            batch_u_global,  # B_u global indices\n",
    "            temperature=100,\n",
    "            eps=1e-10\n",
    "        ) * args['tsne_weight']\n",
    "\n",
    "\n",
    "        tsne_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # print(f\"Epoch {epoch}: Loss={tsne_loss.item():.4f}\")\n",
    "\n",
    "        if tsne_loss.item() < pretrain_best_loss:\n",
    "            pretrain_best_loss = tsne_loss.item()\n",
    "            pretrain_no_improve = 0\n",
    "        else:\n",
    "            pretrain_no_improve += 1\n",
    "\n",
    "        if pretrain_no_improve >= args['pretrain_patience']:\n",
    "            print(f\"Pretrain early stopping at epoch {epoch}\")\n",
    "            pretrain_early_stop = True\n",
    "\n",
    "        writer.add_scalar('Pretrain/TSNE_Loss', tsne_loss.item(), epoch)\n",
    "\n",
    "        if epoch % 5 == 0:\n",
    "            print(f'Pretrain Epoch: {epoch:03d}, TSNE Loss: {tsne_loss.item():.4f}')\n",
    "\n",
    "    print(\"\\n=== Starting Fine-tuning ===\")\n",
    "    # Unfreeze the classifier and fine-tune the whole model.\n",
    "    gnn_model.classifier.requires_grad_(True)\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        filter(lambda p: p.requires_grad, gnn_model.parameters()),\n",
    "        lr=args['finetune_lr'],\n",
    "        weight_decay=5e-5\n",
    "    )\n",
    "    # NOTE(review): the 'verbose' kwarg of ReduceLROnPlateau is deprecated\n",
    "    # in recent PyTorch releases.\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "        optimizer, mode='max', factor=0.5, patience=10, verbose=True\n",
    "    )\n",
    "\n",
    "    best_val_auc = 0.0\n",
    "    best_model_state = None\n",
    "    train_pos, train_neg = pos_neg_split(idx_train, y_train)\n",
    "\n",
    "    no_improve_epochs = 0\n",
    "    early_stop = False\n",
    "\n",
    "    for epoch in range(args['num_epochs']):\n",
    "        if early_stop:\n",
    "            break\n",
    "\n",
    "        gnn_model.train()\n",
    "        total_loss = 0.0\n",
    "\n",
    "        # Sample a labelled batch of centers and expand it to a subgraph;\n",
    "        # batch_mask keeps only the positions of the original centers.\n",
    "        batch_centers = rd.sample(train_pos + train_neg, args['batch_size'])\n",
    "        sub_nodes = sample_subgraph(batch_centers, dist_matrix, args['sample_size'])\n",
    "        batch_mask = [i for i, node in enumerate(sub_nodes) if node in batch_centers]\n",
    "\n",
    "        feat_sub = feat_data[sub_nodes]\n",
    "        labels_sub = labels[sub_nodes]\n",
    "\n",
    "        # Multiple gradient steps on the same sampled subgraph per epoch.\n",
    "        for _ in range(len(sub_nodes) // args['batch_size']):\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            out, _ = gnn_model(feat_sub, edge_indexs, sub_nodes=None)\n",
    "            cls_loss = F.nll_loss(out[batch_mask], torch.LongTensor(labels_sub[batch_mask]).to(device))\n",
    "\n",
    "            cls_loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += cls_loss.item()\n",
    "\n",
    "        avg_loss = total_loss / (len(sub_nodes) // args['batch_size'])\n",
    "        writer.add_scalar('FineTune/Train_Loss', avg_loss, epoch)\n",
    "\n",
    "        # Validate every 5 epochs; early-stop on stagnant validation AUC.\n",
    "        if epoch % 5 == 0:\n",
    "            val_auc, val_ap, val_f1, val_g_mean = test(idx_val, y_val, gnn_model, feat_data, edge_indexs, device)\n",
    "\n",
    "            writer.add_scalar('Validation/AUC', val_auc, epoch)\n",
    "            writer.add_scalar('Validation/F1', val_f1, epoch)\n",
    "            writer.add_scalar('Validation/GMean', val_g_mean, epoch)\n",
    "\n",
    "            print(f'Epoch: {epoch:03d} | Loss: {avg_loss:.4f} | Val AUC: {val_auc:.4f} | Val F1: {val_f1:.4f}')\n",
    "\n",
    "            scheduler.step(val_auc)\n",
    "\n",
    "            if val_auc > best_val_auc:\n",
    "                best_val_auc = val_auc\n",
    "                no_improve_epochs = 0\n",
    "                best_model_state = copy.deepcopy(gnn_model.state_dict())\n",
    "            else:\n",
    "                no_improve_epochs += 1\n",
    "\n",
    "            if no_improve_epochs >= args['patience']:\n",
    "                print(f\"Early stopping at epoch {epoch}\")\n",
    "                early_stop = True\n",
    "\n",
    "    # Restore the best validation checkpoint and report test metrics.\n",
    "    gnn_model.load_state_dict(best_model_state)\n",
    "    test_auc, test_ap, test_f1, test_g_mean = test(idx_test, y_test, gnn_model, feat_data, edge_indexs, device)\n",
    "    print(f'\\n=== Final Test Results ===')\n",
    "    print(f'Test AUC: {test_auc:.4f} | Test AP: {test_ap:.4f} | Test F1: {test_f1:.4f} | G-mean: {test_g_mean:.4f}')\n",
    "    writer.close()\n",
    "\n",
    "# Hyper-parameter configuration for the Amazon run.\n",
    "# NOTE(review): 'weight_decay', 'emb_size', 'weight' and 'layers' are not\n",
    "# read by bsne_main (weight decay is hard-coded to 5e-5 there).\n",
    "args = {\n",
    "    \"dataset\": \"amazon\",\n",
    "    \"batch_size\": 128,\n",
    "    \"sample_size\": 50,\n",
    "    \"weight_decay\": 0.00005,\n",
    "    \"emb_size\": 32,\n",
    "    \"pretrain_epochs\": 200,\n",
    "    \"pretrain_lr\": 0.001,\n",
    "    \"finetune_lr\": 0.0005,\n",
    "    \"num_epochs\": 500,\n",
    "    \"pretrain_patience\": 20,\n",
    "    \"patience\": 30,\n",
    "    \"tsne_weight\": 0.3,\n",
    "    \"weight\": 0.6,\n",
    "    \"layers\": 7,\n",
    "    \"test_size\": 0.6,\n",
    "    \"val_size\": 0.5,\n",
    "    \"layers_tree\": 7,\n",
    "    \"seed\": 76,\n",
    "    \"num_heads\": 2,\n",
    "    \"drop_rate\": 0.5\n",
    "}\n",
    "bsne_main(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "a3aad308",
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "loading data...\n",
      "\n",
      "=== Starting Fine-tuning ===\n",
      "Epoch: 000 | Loss: 1.8113 | Val AUC: 0.8883 | Val F1: 0.8803\n",
      "Epoch: 005 | Loss: 0.2247 | Val AUC: 0.9002 | Val F1: 0.8857\n",
      "Epoch: 010 | Loss: 0.0837 | Val AUC: 0.9174 | Val F1: 0.8905\n",
      "Epoch: 015 | Loss: 0.0624 | Val AUC: 0.9296 | Val F1: 0.9156\n",
      "Epoch: 020 | Loss: 0.1623 | Val AUC: 0.7703 | Val F1: 0.8370\n",
      "Epoch: 025 | Loss: 0.0796 | Val AUC: 0.9440 | Val F1: 0.9170\n",
      "Epoch: 030 | Loss: 0.1586 | Val AUC: 0.8984 | Val F1: 0.8928\n",
      "Epoch: 035 | Loss: 0.0832 | Val AUC: 0.9203 | Val F1: 0.9203\n",
      "Epoch: 040 | Loss: 0.1095 | Val AUC: 0.9322 | Val F1: 0.9176\n",
      "Epoch: 045 | Loss: 0.1339 | Val AUC: 0.9100 | Val F1: 0.8739\n",
      "Epoch: 050 | Loss: 0.1105 | Val AUC: 0.9339 | Val F1: 0.9160\n",
      "Epoch: 055 | Loss: 0.0749 | Val AUC: 0.9307 | Val F1: 0.9098\n",
      "Epoch: 060 | Loss: 0.1972 | Val AUC: 0.9281 | Val F1: 0.9031\n",
      "Epoch: 065 | Loss: 0.2574 | Val AUC: 0.8758 | Val F1: 0.8524\n",
      "Epoch: 070 | Loss: 0.0861 | Val AUC: 0.9245 | Val F1: 0.8973\n",
      "Epoch: 075 | Loss: 0.1184 | Val AUC: 0.9283 | Val F1: 0.9006\n",
      "Epoch: 080 | Loss: 0.0610 | Val AUC: 0.9166 | Val F1: 0.9030\n",
      "Epoch 00017: reducing learning rate of group 0 to 2.5000e-04.\n",
      "Epoch: 085 | Loss: 0.0502 | Val AUC: 0.9416 | Val F1: 0.9129\n",
      "Epoch: 090 | Loss: 0.0946 | Val AUC: 0.8822 | Val F1: 0.8607\n",
      "Epoch: 095 | Loss: 0.0350 | Val AUC: 0.9419 | Val F1: 0.8996\n",
      "Epoch: 100 | Loss: 0.1188 | Val AUC: 0.9411 | Val F1: 0.8926\n",
      "Epoch: 105 | Loss: 0.0375 | Val AUC: 0.9478 | Val F1: 0.9224\n",
      "Epoch: 110 | Loss: 0.0435 | Val AUC: 0.9383 | Val F1: 0.9140\n",
      "Epoch: 115 | Loss: 0.1565 | Val AUC: 0.9306 | Val F1: 0.9081\n",
      "Epoch: 120 | Loss: 0.2723 | Val AUC: 0.9208 | Val F1: 0.8431\n",
      "Epoch: 125 | Loss: 0.0727 | Val AUC: 0.9407 | Val F1: 0.9086\n",
      "Epoch: 130 | Loss: 0.0308 | Val AUC: 0.9326 | Val F1: 0.8968\n",
      "Epoch: 135 | Loss: 0.0799 | Val AUC: 0.9527 | Val F1: 0.8875\n",
      "Epoch: 140 | Loss: 0.0995 | Val AUC: 0.9500 | Val F1: 0.8990\n",
      "Epoch: 145 | Loss: 0.0403 | Val AUC: 0.9555 | Val F1: 0.9129\n",
      "Epoch: 150 | Loss: 0.0481 | Val AUC: 0.9512 | Val F1: 0.9082\n",
      "Epoch: 155 | Loss: 0.0997 | Val AUC: 0.9247 | Val F1: 0.8924\n",
      "Epoch: 160 | Loss: 0.0604 | Val AUC: 0.9466 | Val F1: 0.9113\n",
      "Epoch: 165 | Loss: 0.0707 | Val AUC: 0.9545 | Val F1: 0.9014\n",
      "Epoch: 170 | Loss: 0.0526 | Val AUC: 0.9480 | Val F1: 0.9108\n",
      "Epoch: 175 | Loss: 0.0229 | Val AUC: 0.9488 | Val F1: 0.9083\n",
      "Epoch: 180 | Loss: 0.2352 | Val AUC: 0.9401 | Val F1: 0.9083\n",
      "Epoch: 185 | Loss: 0.1719 | Val AUC: 0.9333 | Val F1: 0.8573\n",
      "Epoch: 190 | Loss: 0.0139 | Val AUC: 0.9555 | Val F1: 0.9103\n",
      "Epoch: 195 | Loss: 0.0611 | Val AUC: 0.9453 | Val F1: 0.9088\n",
      "Epoch: 200 | Loss: 0.0435 | Val AUC: 0.9305 | Val F1: 0.8990\n",
      "Epoch 00041: reducing learning rate of group 0 to 1.2500e-04.\n",
      "Epoch: 205 | Loss: 0.0872 | Val AUC: 0.9470 | Val F1: 0.9092\n",
      "Epoch: 210 | Loss: 0.0537 | Val AUC: 0.9493 | Val F1: 0.9124\n",
      "Epoch: 215 | Loss: 0.0530 | Val AUC: 0.9458 | Val F1: 0.9140\n",
      "Epoch: 220 | Loss: 0.0531 | Val AUC: 0.9544 | Val F1: 0.9176\n",
      "Epoch: 225 | Loss: 0.0504 | Val AUC: 0.9465 | Val F1: 0.8978\n",
      "Epoch: 230 | Loss: 0.0473 | Val AUC: 0.9480 | Val F1: 0.9026\n",
      "Epoch: 235 | Loss: 0.0402 | Val AUC: 0.9509 | Val F1: 0.9218\n",
      "Epoch: 240 | Loss: 0.0200 | Val AUC: 0.9544 | Val F1: 0.9114\n",
      "Epoch: 245 | Loss: 0.1004 | Val AUC: 0.9209 | Val F1: 0.8697\n",
      "Epoch: 250 | Loss: 0.1040 | Val AUC: 0.9401 | Val F1: 0.9064\n",
      "Epoch: 255 | Loss: 0.0202 | Val AUC: 0.9484 | Val F1: 0.9155\n",
      "Epoch 00052: reducing learning rate of group 0 to 6.2500e-05.\n",
      "Epoch: 260 | Loss: 0.0554 | Val AUC: 0.9512 | Val F1: 0.9021\n",
      "Epoch: 265 | Loss: 0.0192 | Val AUC: 0.9451 | Val F1: 0.8977\n",
      "Epoch: 270 | Loss: 0.0635 | Val AUC: 0.9454 | Val F1: 0.9092\n",
      "Epoch: 275 | Loss: 0.0101 | Val AUC: 0.9484 | Val F1: 0.9097\n",
      "Epoch: 280 | Loss: 0.0070 | Val AUC: 0.9494 | Val F1: 0.8931\n",
      "Epoch: 285 | Loss: 0.0841 | Val AUC: 0.9462 | Val F1: 0.9093\n",
      "Epoch: 290 | Loss: 0.0431 | Val AUC: 0.9530 | Val F1: 0.9104\n",
      "Epoch: 295 | Loss: 0.0705 | Val AUC: 0.9531 | Val F1: 0.8954\n",
      "Epoch: 300 | Loss: 0.0722 | Val AUC: 0.9488 | Val F1: 0.9005\n",
      "Epoch: 305 | Loss: 0.0652 | Val AUC: 0.9435 | Val F1: 0.8950\n",
      "Epoch: 310 | Loss: 0.0328 | Val AUC: 0.9506 | Val F1: 0.8878\n",
      "Epoch 00063: reducing learning rate of group 0 to 3.1250e-05.\n",
      "Epoch: 315 | Loss: 0.0381 | Val AUC: 0.9463 | Val F1: 0.9024\n",
      "Epoch: 320 | Loss: 0.0478 | Val AUC: 0.9504 | Val F1: 0.8975\n",
      "Epoch: 325 | Loss: 0.0214 | Val AUC: 0.9532 | Val F1: 0.9024\n",
      "Epoch: 330 | Loss: 0.0343 | Val AUC: 0.9510 | Val F1: 0.8990\n",
      "Epoch: 335 | Loss: 0.0571 | Val AUC: 0.9448 | Val F1: 0.8969\n",
      "Epoch: 340 | Loss: 0.0131 | Val AUC: 0.9531 | Val F1: 0.9104\n",
      "Early stopping at epoch 340\n",
      "\n",
      "=== Final Test Results ===\n",
      "Test AUC: 0.9640 | Test AP: 0.8766 | Test F1: 0.9111 | G-mean: 0.8829\n"
     ]
    }
   ],
   "source": [
    "# Classification-only training (ablation study)\n",
    "def bsne_main(args):\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    # device = torch.device('cpu')\n",
    "    print(device)\n",
    "\n",
    "    timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n",
    "    writer = SummaryWriter(f'runs/{args[\"dataset\"]}_{timestamp}')\n",
    "\n",
    "    print('loading data...')\n",
    "    prefix = \"/data/run01/sczc619/LML/MetaTSNE/data/\"\n",
    "    edge_indexs, feat_data, labels = load_data(args['dataset'], args['layers_tree'], prefix)\n",
    "\n",
    "    np.random.seed(args['seed'])\n",
    "    rd.seed(args['seed'])\n",
    "\n",
    "    if args['dataset'] == 'yelp':\n",
    "        index = list(range(len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels, stratify=labels,\n",
    "                                                                        test_size=args['test_size'], random_state=2,\n",
    "                                                                        shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"Yelp_shortest_distance.pkl\")\n",
    "    elif args['dataset'] == 'amazon':\n",
    "        index = list(range(3305, len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels[3305:],\n",
    "                                                                        stratify=labels[3305:],\n",
    "                                                                        test_size=args['test_size'],\n",
    "                                                                        random_state=2, shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"Amazon_shortest_distance.pkl\")\n",
    "\n",
    "    with open(dist_path, 'rb') as f:\n",
    "        dist_data = pickle.load(f)\n",
    "        dist_matrix = torch.tensor(dist_data['dist_matrix']).to(device)\n",
    "\n",
    "\n",
    "    adj_dict = defaultdict(list)\n",
    "    for rel in edge_indexs:\n",
    "        edge_index = rel[0].cpu().numpy()\n",
    "        for src, dst in zip(edge_index[0], edge_index[1]):\n",
    "            adj_dict[src].append(dst)\n",
    "\n",
    "    gnn_model = multi_HOGRL_Transformer(\n",
    "        in_feat=feat_data.shape[1],\n",
    "        out_feat=2,\n",
    "        relation_nums=len(edge_indexs),\n",
    "        d_model=128,\n",
    "        nhead=2,\n",
    "        num_layers=3,\n",
    "        dim_feedforward=256,\n",
    "        drop_rate=args['drop_rate'],\n",
    "        layers_tree=args['layers_tree'],\n",
    "        tsne_weight=args['tsne_weight']\n",
    "    ).to(device)\n",
    "\n",
    "    for edge_index in edge_indexs:\n",
    "        edge_index[0] = edge_index[0].to(device)\n",
    "        edge_index[1] = [tensor.to(device) for tensor in edge_index[1]]\n",
    "    feat_data = torch.tensor(feat_data).float().to(device)\n",
    "    \n",
    "    \n",
    "    print(\"\\n=== Starting Fine-tuning ===\")\n",
    "    gnn_model.classifier.requires_grad_(True)\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        filter(lambda p: p.requires_grad, gnn_model.parameters()),\n",
    "        lr=args['finetune_lr'],\n",
    "        weight_decay=5e-5\n",
    "    )\n",
    "    # NOTE(review): the 'verbose' kwarg was deprecated and later removed from\n",
    "    # ReduceLROnPlateau in newer PyTorch releases; drop it when upgrading.\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "        optimizer, mode='max', factor=0.5, patience=10, verbose=True\n",
    "    )\n",
    "\n",
    "    best_val_auc = 0.0\n",
    "    best_model_state = None\n",
    "    train_pos, train_neg = pos_neg_split(idx_train, y_train)\n",
    "\n",
    "    no_improve_epochs = 0\n",
    "    early_stop = False\n",
    "\n",
    "    for epoch in range(args['num_epochs']):\n",
    "        if early_stop:\n",
    "            break\n",
    "\n",
    "        gnn_model.train()\n",
    "        total_loss = 0.0\n",
    "\n",
    "        batch_centers = rd.sample(train_pos + train_neg, args['batch_size'])\n",
    "        sub_nodes = sample_subgraph(batch_centers, dist_matrix, args['sample_size'])\n",
    "        batch_mask = [i for i, node in enumerate(sub_nodes) if node in batch_centers]\n",
    "\n",
    "        feat_sub = feat_data[sub_nodes]\n",
    "        labels_sub = labels[sub_nodes]\n",
    "\n",
    "        # Guard: if the sampled subgraph is smaller than batch_size, the\n",
    "        # original floor division gave 0 iterations and then a\n",
    "        # ZeroDivisionError when averaging the loss below.\n",
    "        num_iters = max(1, len(sub_nodes) // args['batch_size'])\n",
    "        for _ in range(num_iters):\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            out, _ = gnn_model(feat_sub, edge_indexs, sub_nodes=None)\n",
    "            cls_loss = F.nll_loss(out[batch_mask], torch.LongTensor(labels_sub[batch_mask]).to(device))\n",
    "\n",
    "            cls_loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "            total_loss += cls_loss.item()\n",
    "\n",
    "        avg_loss = total_loss / num_iters\n",
    "        writer.add_scalar('FineTune/Train_Loss', avg_loss, epoch)\n",
    "\n",
    "        if epoch % 5 == 0:\n",
    "            val_auc, val_ap, val_f1, val_g_mean = test(idx_val, y_val, gnn_model, feat_data, edge_indexs, device)\n",
    "\n",
    "            writer.add_scalar('Validation/AUC', val_auc, epoch)\n",
    "            writer.add_scalar('Validation/F1', val_f1, epoch)\n",
    "            writer.add_scalar('Validation/GMean', val_g_mean, epoch)\n",
    "\n",
    "            print(f'Epoch: {epoch:03d} | Loss: {avg_loss:.4f} | Val AUC: {val_auc:.4f} | Val F1: {val_f1:.4f}')\n",
    "\n",
    "            scheduler.step(val_auc)\n",
    "\n",
    "            if val_auc > best_val_auc:\n",
    "                best_val_auc = val_auc\n",
    "                no_improve_epochs = 0\n",
    "                best_model_state = copy.deepcopy(gnn_model.state_dict())\n",
    "            else:\n",
    "                no_improve_epochs += 1\n",
    "\n",
    "            if no_improve_epochs >= args['patience']:\n",
    "                print(f\"Early stopping at epoch {epoch}\")\n",
    "                early_stop = True\n",
    "\n",
    "    # Guard against the case where validation never improved and\n",
    "    # best_model_state is still None.\n",
    "    if best_model_state is not None:\n",
    "        gnn_model.load_state_dict(best_model_state)\n",
    "    test_auc, test_ap, test_f1, test_g_mean = test(idx_test, y_test, gnn_model, feat_data, edge_indexs, device)\n",
    "    print(f'\\n=== Final Test Results ===')\n",
    "    print(f'Test AUC: {test_auc:.4f} | Test AP: {test_ap:.4f} | Test F1: {test_f1:.4f} | G-mean: {test_g_mean:.4f}')\n",
    "    writer.close()\n",
    "\n",
    "args = {\n",
    "    \"dataset\": \"amazon\",\n",
    "    \"batch_size\": 128,\n",
    "    \"sample_size\": 50,\n",
    "    \"weight_decay\": 0.00005,\n",
    "    \"emb_size\": 32,\n",
    "    \"pretrain_epochs\": 200,\n",
    "    \"pretrain_lr\": 0.001,\n",
    "    \"finetune_lr\": 0.0005,\n",
    "    \"num_epochs\": 500,\n",
    "    \"pretrain_patience\": 20,\n",
    "    \"patience\": 30,\n",
    "    \"tsne_weight\": 0.3,\n",
    "    \"weight\": 0.6,\n",
    "    \"layers\": 7,\n",
    "    \"test_size\": 0.6,\n",
    "    \"val_size\": 0.5,\n",
    "    \"layers_tree\": 7,\n",
    "    \"seed\": 76,\n",
    "    \"num_heads\": 2,\n",
    "    \"drop_rate\": 0.5\n",
    "}\n",
    "bsne_main(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e356ab1b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "cuda\n",
      "loading data...\n",
      "\n",
      "=== Starting Pretraining ===\n",
      "Epoch: 9, Val AUC: 0.6613, Val AP: 0.1657, Val F1: 0.3708, Val G-mean: 0.5626\n",
      "Epoch: 19, Val AUC: 0.8135, Val AP: 0.5044, Val F1: 0.3984, Val G-mean: 0.5935\n",
      "Epoch: 29, Val AUC: 0.8350, Val AP: 0.6637, Val F1: 0.4063, Val G-mean: 0.6013\n",
      "Epoch: 39, Val AUC: 0.8458, Val AP: 0.7216, Val F1: 0.4121, Val G-mean: 0.6078\n",
      "Epoch: 49, Val AUC: 0.8446, Val AP: 0.7134, Val F1: 0.4133, Val G-mean: 0.6092\n",
      "Epoch: 59, Val AUC: 0.8369, Val AP: 0.6643, Val F1: 0.4096, Val G-mean: 0.6050\n"
     ]
    }
   ],
   "source": [
    "# Original SNE algorithm\n",
    "\n",
    "def calculate_sne_loss(emb, dist_matrix, temperature=1, eps=1e-12):\n",
    "    \"\"\"\n",
    "    :param emb: 全图嵌入 [num_nodes, d_model]\n",
    "    :param dist_matrix: 全图距离矩阵 [num_nodes, num_nodes]\n",
    "    \"\"\"\n",
    "    device = emb.device\n",
    "    num_nodes = emb.size(0)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        P = (1.0 + dist_matrix ** 2) ** -1\n",
    "        P.fill_diagonal_(0)\n",
    "        P = (P + P.T) / 2 \n",
    "        P = P / (P.sum(dim=1, keepdim=True) + eps)\n",
    "        P = torch.clamp(P, min=eps)\n",
    "\n",
    "    pairwise_dist = torch.cdist(emb, emb, p=2)\n",
    "    Q = (1.0 + pairwise_dist ** 2 / temperature) ** -1\n",
    "    Q.fill_diagonal_(0)\n",
    "    Q = (Q + Q.T) / 2\n",
    "    Q = Q / (Q.sum(dim=1, keepdim=True) + eps)\n",
    "    Q = torch.clamp(Q, min=eps)\n",
    "\n",
    "    loss = (P * (torch.log(P) - torch.log(Q))).sum()\n",
    "\n",
    "    return loss\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "def sne_main(args):\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    print(device)\n",
    "\n",
    "    timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n",
    "    writer = SummaryWriter(f'runs/{args[\"dataset\"]}_{timestamp}')\n",
    "\n",
    "    print('loading data...')\n",
    "    prefix = \"../../data/\"\n",
    "\n",
    "    edge_indexs, feat_data, labels = load_data(args['dataset'], args['layers_tree'], prefix)\n",
    "\n",
    "    np.random.seed(args['seed'])\n",
    "    rd.seed(args['seed'])\n",
    "\n",
    "    if args['dataset'] == 'yelp':\n",
    "        index = list(range(len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels, stratify=labels,\n",
    "                                                                        test_size=args['test_size'], random_state=2,\n",
    "                                                                        shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"YelpChi_shortest_distance.pkl\")\n",
    "    elif args['dataset'] == 'amazon':\n",
    "        index = list(range(3305, len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels[3305:],\n",
    "                                                                        stratify=labels[3305:],\n",
    "                                                                        test_size=args['test_size'],\n",
    "                                                                        random_state=2, shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"Amazon_shortest_distance.pkl\")\n",
    "\n",
    "    with open(dist_path, 'rb') as f:\n",
    "        dist_data = pickle.load(f)\n",
    "        dist_matrix = torch.tensor(dist_data['dist_matrix']).to(device)\n",
    "\n",
    "    adj_dict = defaultdict(list)\n",
    "    for rel in edge_indexs:\n",
    "        edge_index = rel[0].cpu().numpy()\n",
    "        for src, dst in zip(edge_index[0], edge_index[1]):\n",
    "            adj_dict[src].append(dst)\n",
    "\n",
    "    sne_model = multi_HOGRL_Transformer(\n",
    "        in_feat=feat_data.shape[1],\n",
    "        out_feat=2,\n",
    "        relation_nums=len(edge_indexs),\n",
    "        d_model=64,\n",
    "        nhead=2,\n",
    "        num_layers=1,\n",
    "        dim_feedforward=128,\n",
    "        drop_rate=args['drop_rate'],\n",
    "        layers_tree=args['layers_tree'],\n",
    "        tsne_weight=args['tsne_weight']\n",
    "    ).to(device)\n",
    "\n",
    "    for edge_index in edge_indexs:\n",
    "        edge_index[0] = edge_index[0].to(device)\n",
    "        edge_index[1] = [tensor.to(device) for tensor in edge_index[1]]\n",
    "    feat_data = torch.tensor(feat_data).float().to(device)\n",
    "\n",
    "    print(\"\\n=== Starting Pretraining ===\")\n",
    "\n",
    "    sne_model.classifier.requires_grad_(False)\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        filter(lambda p: p.requires_grad, sne_model.parameters()),\n",
    "        lr=args['pretrain_lr'],\n",
    "        weight_decay=5e-5\n",
    "    )\n",
    "\n",
    "    best_val_auc = 0.0\n",
    "    best_model_state = None\n",
    "    train_pos, train_neg = pos_neg_split(idx_train, y_train)\n",
    "    \n",
    "    epoch_times = [] \n",
    "    \n",
    "    for epoch in range(args['num_epochs']):\n",
    "        start_time = time.time()\n",
    "        sne_model.train()\n",
    "        loss = 0\n",
    "        sampled_idx_train = undersample(train_pos, train_neg, scale=1)\n",
    "        rd.shuffle(sampled_idx_train)\n",
    "\n",
    "        # Ceil division; the original int(len / bs) + 1 produced an extra\n",
    "        # *empty* batch whenever len(sampled_idx_train) was an exact multiple\n",
    "        # of batch_size, and labels[np.array([])] then fails because an empty\n",
    "        # np.array has float dtype and is not a valid integer index.\n",
    "        num_batches = (len(sampled_idx_train) + args['batch_size'] - 1) // args['batch_size']\n",
    "        for batch in range(num_batches):\n",
    "            i_start = batch * args['batch_size']\n",
    "            i_end = min((batch + 1) * args['batch_size'], len(sampled_idx_train))\n",
    "            batch_nodes = sampled_idx_train[i_start:i_end]\n",
    "            # NOTE(review): batch_label and batch_nodes_tensor are never used —\n",
    "            # calculate_sne_loss below runs on the full-graph embedding, so every\n",
    "            # inner iteration repeats the same full-graph update regardless of batch.\n",
    "            batch_label = torch.tensor(labels[np.array(batch_nodes)]).long().to(device)\n",
    "            optimizer.zero_grad()\n",
    "            out, embedding = sne_model(feat_data, edge_indexs)\n",
    "            batch_nodes_tensor = torch.tensor(batch_nodes, dtype=torch.long, device=device)\n",
    "            tsne_loss = calculate_sne_loss(\n",
    "                embedding,\n",
    "                dist_matrix,\n",
    "                temperature=10,\n",
    "                eps=1e-12\n",
    "            )\n",
    "            loss = tsne_loss * args['tsne_weight']\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "            \n",
    "        end_time = time.time()  \n",
    "        epoch_times.append(end_time - start_time) \n",
    "\n",
    "        if epoch % 10 == 9:  \n",
    "            val_auc, val_ap, val_f1, val_g_mean = test(idx_val, y_val, sne_model, feat_data, edge_indexs, device)\n",
    "            writer.add_scalar('Validation/AUC', val_auc, epoch)\n",
    "            writer.add_scalar('Validation/F1', val_f1, epoch)\n",
    "            writer.add_scalar('Validation/GMean', val_g_mean, epoch)\n",
    "            print(f'Epoch: {epoch}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}, Val F1: {val_f1:.4f}, Val G-mean: {val_g_mean:.4f}')\n",
    "\n",
    "            if val_auc > best_val_auc:\n",
    "                best_val_auc = val_auc\n",
    "                best_model_state = copy.deepcopy(sne_model.state_dict())\n",
    "                \n",
    "    avg_epoch_time = sum(epoch_times) / len(epoch_times)\n",
    "    print(f\"Average epoch time for SNE: {avg_epoch_time:.4f} seconds\")\n",
    "                \n",
    "    print(\"\\n=== Starting Fine-tuning ===\")\n",
    "    # Skip restore if pretraining validation never improved (state is None).\n",
    "    if best_model_state is not None:\n",
    "        sne_model.load_state_dict(best_model_state)\n",
    "    sne_model.classifier.requires_grad_(True)\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        filter(lambda p: p.requires_grad, sne_model.parameters()),\n",
    "        lr=args['finetune_lr'],\n",
    "        weight_decay=5e-5\n",
    "    )\n",
    "    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(\n",
    "        optimizer, mode='max', factor=0.5, patience=10, verbose=True\n",
    "    )\n",
    "\n",
    "    best_val_auc = 0.0\n",
    "    best_model_state = None\n",
    "    train_pos, train_neg = pos_neg_split(idx_train, y_train)\n",
    "\n",
    "    no_improve_epochs = 0\n",
    "    early_stop = False\n",
    "\n",
    "    for epoch in range(args['num_epochs']):\n",
    "        if early_stop:\n",
    "            break\n",
    "\n",
    "        sne_model.train()\n",
    "        total_loss = 0.0\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "        out, _ = sne_model(feat_data, edge_indexs)\n",
    "        loss = F.nll_loss(out[idx_train], torch.LongTensor(y_train).to(device))\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        writer.add_scalar('FineTune/Train_Loss', loss.item(), epoch)\n",
    "\n",
    "        if epoch % 5 == 0:\n",
    "            val_auc, val_ap, val_f1, val_g_mean = test(idx_val, y_val, sne_model, feat_data, edge_indexs, device)\n",
    "\n",
    "            writer.add_scalar('Validation/AUC', val_auc, epoch)\n",
    "            writer.add_scalar('Validation/F1', val_f1, epoch)\n",
    "            writer.add_scalar('Validation/GMean', val_g_mean, epoch)\n",
    "\n",
    "            print(f'Epoch: {epoch:03d} | Loss: {loss.item():.4f} | Val AUC: {val_auc:.4f} | Val F1: {val_f1:.4f}')\n",
    "\n",
    "            scheduler.step(val_auc)\n",
    "\n",
    "            if val_auc > best_val_auc:\n",
    "                best_val_auc = val_auc\n",
    "                no_improve_epochs = 0\n",
    "                best_model_state = copy.deepcopy(sne_model.state_dict())\n",
    "            else:\n",
    "                no_improve_epochs += 1\n",
    "\n",
    "            if no_improve_epochs >= args['patience']:\n",
    "                print(f\"Early stopping at epoch {epoch}\")\n",
    "                early_stop = True\n",
    "\n",
    "    # Guard against best_model_state never being set during fine-tuning.\n",
    "    if best_model_state is not None:\n",
    "        sne_model.load_state_dict(best_model_state)\n",
    "    test_auc, test_ap, test_f1, test_g_mean = test(idx_test, y_test, sne_model, feat_data, edge_indexs, device)\n",
    "    print(f'\\n=== Final Test Results ===')\n",
    "    print(f'Test AUC: {test_auc:.4f} | Test AP: {test_ap:.4f} | Test F1: {test_f1:.4f} | G-mean: {test_g_mean:.4f}')\n",
    "    writer.close()\n",
    "    \n",
    "    \n",
    "\n",
    "\n",
    "args = {\n",
    "    \"dataset\": \"amazon\",\n",
    "    \"batch_size\": 32,\n",
    "    \"weight_decay\": 0.00005,\n",
    "    \"pretrain_epochs\": 200,\n",
    "    \"pretrain_lr\": 0.001,\n",
    "    \"finetune_lr\": 0.0005,\n",
    "    \"num_epochs\": 500,\n",
    "    \"pretrain_patience\": 30,\n",
    "    \"tsne_weight\": 0.3,\n",
    "    \"weight\": 0.6,\n",
    "    \"layers\": 4,\n",
    "    \"test_size\": 0.6,\n",
    "    \"val_size\": 0.5,\n",
    "    \"layers_tree\": 1,\n",
    "    \"seed\": 76,\n",
    "    \"num_heads\": 2,\n",
    "    \"drop_rate\": 0.5\n",
    "}\n",
    "\n",
    "sne_main(args)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "c7d2364b",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running SNE pretraining...\n",
      "cuda\n",
      "loading data...\n",
      "\n",
      "=== Starting Pretraining ===\n",
      "Epoch: 9, Val AUC: 0.3459, Val AP: 0.0669, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 19, Val AUC: 0.4312, Val AP: 0.0762, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 29, Val AUC: 0.4563, Val AP: 0.0798, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 39, Val AUC: 0.4147, Val AP: 0.0747, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 49, Val AUC: 0.3852, Val AP: 0.0715, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 59, Val AUC: 0.4167, Val AP: 0.0749, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 69, Val AUC: 0.4504, Val AP: 0.0790, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 79, Val AUC: 0.4893, Val AP: 0.0844, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 89, Val AUC: 0.5098, Val AP: 0.0876, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 99, Val AUC: 0.5645, Val AP: 0.0976, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 109, Val AUC: 0.5669, Val AP: 0.0979, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 119, Val AUC: 0.5694, Val AP: 0.0987, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 129, Val AUC: 0.5835, Val AP: 0.1019, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 139, Val AUC: 0.5840, Val AP: 0.1018, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 149, Val AUC: 0.5893, Val AP: 0.1030, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 159, Val AUC: 0.5958, Val AP: 0.1044, Val F1: 0.0867, Val G-mean: 0.0000\n",
      "Epoch: 169, Val AUC: 0.5943, Val AP: 0.1043, Val F1: 0.0880, Val G-mean: 0.0358\n",
      "Epoch: 179, Val AUC: 0.5759, Val AP: 0.1002, Val F1: 0.1056, Val G-mean: 0.1338\n",
      "Epoch: 189, Val AUC: 0.5734, Val AP: 0.1000, Val F1: 0.1614, Val G-mean: 0.2727\n",
      "Epoch: 199, Val AUC: 0.5563, Val AP: 0.0973, Val F1: 0.1676, Val G-mean: 0.2848\n",
      "Epoch: 209, Val AUC: 0.5512, Val AP: 0.0968, Val F1: 0.2062, Val G-mean: 0.3501\n",
      "Epoch: 219, Val AUC: 0.5322, Val AP: 0.0934, Val F1: 0.2481, Val G-mean: 0.4077\n",
      "Epoch: 229, Val AUC: 0.5503, Val AP: 0.0966, Val F1: 0.2593, Val G-mean: 0.4265\n",
      "Epoch: 239, Val AUC: 0.5095, Val AP: 0.0888, Val F1: 0.2623, Val G-mean: 0.4240\n",
      "Epoch: 249, Val AUC: 0.5043, Val AP: 0.0882, Val F1: 0.2537, Val G-mean: 0.4120\n",
      "Epoch: 299, Val AUC: 0.4923, Val AP: 0.0878, Val F1: 0.1305, Val G-mean: 0.2107\n",
      "Epoch: 309, Val AUC: 0.4662, Val AP: 0.0827, Val F1: 0.1248, Val G-mean: 0.2009\n",
      "Epoch: 319, Val AUC: 0.4998, Val AP: 0.0877, Val F1: 0.1280, Val G-mean: 0.2087\n",
      "Epoch: 329, Val AUC: 0.5131, Val AP: 0.0900, Val F1: 0.1415, Val G-mean: 0.2365\n",
      "Epoch: 339, Val AUC: 0.4986, Val AP: 0.0872, Val F1: 0.1444, Val G-mean: 0.2422\n",
      "Epoch: 349, Val AUC: 0.4870, Val AP: 0.0851, Val F1: 0.1486, Val G-mean: 0.2503\n",
      "Epoch: 359, Val AUC: 0.4933, Val AP: 0.0862, Val F1: 0.1455, Val G-mean: 0.2434\n",
      "Epoch: 369, Val AUC: 0.5002, Val AP: 0.0875, Val F1: 0.1368, Val G-mean: 0.2247\n",
      "Epoch: 379, Val AUC: 0.5095, Val AP: 0.0888, Val F1: 0.1447, Val G-mean: 0.2411\n",
      "Epoch: 389, Val AUC: 0.4912, Val AP: 0.0857, Val F1: 0.1460, Val G-mean: 0.2441\n",
      "Epoch: 399, Val AUC: 0.4930, Val AP: 0.0861, Val F1: 0.1570, Val G-mean: 0.2656\n",
      "Epoch: 409, Val AUC: 0.5165, Val AP: 0.0901, Val F1: 0.1558, Val G-mean: 0.2633\n",
      "Epoch: 419, Val AUC: 0.4932, Val AP: 0.0861, Val F1: 0.1654, Val G-mean: 0.2818\n",
      "Epoch: 429, Val AUC: 0.4997, Val AP: 0.0871, Val F1: 0.1736, Val G-mean: 0.2962\n",
      "Epoch: 439, Val AUC: 0.5044, Val AP: 0.0879, Val F1: 0.1991, Val G-mean: 0.3377\n",
      "Epoch: 449, Val AUC: 0.5016, Val AP: 0.0874, Val F1: 0.2124, Val G-mean: 0.3566\n",
      "Epoch: 459, Val AUC: 0.5255, Val AP: 0.0915, Val F1: 0.2633, Val G-mean: 0.4253\n",
      "Epoch: 469, Val AUC: 0.5195, Val AP: 0.0904, Val F1: 0.2805, Val G-mean: 0.4456\n",
      "Epoch: 479, Val AUC: 0.5297, Val AP: 0.0923, Val F1: 0.3365, Val G-mean: 0.5123\n",
      "Epoch: 489, Val AUC: 0.5362, Val AP: 0.0939, Val F1: 0.3114, Val G-mean: 0.4831\n",
      "Epoch: 499, Val AUC: 0.5436, Val AP: 0.0954, Val F1: 0.3633, Val G-mean: 0.5414\n",
      "Average epoch time for SNE: 3.4191 seconds\n",
      "Running BSNE pretraining...\n",
      "cuda\n",
      "loading data...\n",
      "\n",
      "=== Starting Pretraining ===\n",
      "local_loss: 2.5416290760040283,global_loss: -0.07276153564453125\n",
      "Pretrain Epoch: 000, TSNE Loss: 0.7407\n",
      "local_loss: 2.541555166244507,global_loss: -0.07719773054122925\n",
      "local_loss: 2.600149393081665,global_loss: -0.06129610538482666\n",
      "local_loss: 2.518784761428833,global_loss: -0.06667837500572205\n",
      "local_loss: 2.5855581760406494,global_loss: -0.08442184329032898\n",
      "local_loss: 2.470512866973877,global_loss: -0.045420803129673004\n",
      "Pretrain Epoch: 005, TSNE Loss: 0.7275\n",
      "local_loss: 2.449782609939575,global_loss: -0.05875629559159279\n",
      "local_loss: 2.490546226501465,global_loss: -0.058991771191358566\n",
      "local_loss: 2.4991092681884766,global_loss: -0.08527976274490356\n",
      "local_loss: 2.4424736499786377,global_loss: -0.06326095759868622\n",
      "local_loss: 2.274219274520874,global_loss: -0.043884702026844025\n",
      "Pretrain Epoch: 010, TSNE Loss: 0.6691\n",
      "local_loss: 2.478116989135742,global_loss: -0.11487448215484619\n",
      "local_loss: 2.3764774799346924,global_loss: -0.10060829669237137\n",
      "local_loss: 2.401073455810547,global_loss: -0.09325176477432251\n",
      "local_loss: 2.3497040271759033,global_loss: -0.09230531007051468\n",
      "local_loss: 2.3606104850769043,global_loss: -0.10092014074325562\n",
      "Pretrain Epoch: 015, TSNE Loss: 0.6779\n",
      "local_loss: 2.3371574878692627,global_loss: -0.1099989041686058\n",
      "local_loss: 2.469111680984497,global_loss: -0.17614521086215973\n",
      "local_loss: 2.3285436630249023,global_loss: -0.12974725663661957\n",
      "local_loss: 2.240238904953003,global_loss: -0.13584363460540771\n",
      "local_loss: 2.377485513687134,global_loss: -0.19402246177196503\n",
      "Pretrain Epoch: 020, TSNE Loss: 0.6550\n",
      "local_loss: 2.3163087368011475,global_loss: -0.14824296534061432\n",
      "local_loss: 2.2705140113830566,global_loss: -0.13039986789226532\n",
      "local_loss: 2.2341268062591553,global_loss: -0.1577690690755844\n",
      "local_loss: 2.152979612350464,global_loss: -0.1544172316789627\n",
      "local_loss: 2.171992540359497,global_loss: -0.15713469684123993\n",
      "Pretrain Epoch: 025, TSNE Loss: 0.6045\n",
      "local_loss: 2.1322319507598877,global_loss: -0.17059239745140076\n",
      "local_loss: 2.036839485168457,global_loss: -0.14122891426086426\n",
      "local_loss: 2.0056989192962646,global_loss: -0.19609679281711578\n",
      "local_loss: 1.9986730813980103,global_loss: -0.19381360709667206\n",
      "local_loss: 1.931727409362793,global_loss: -0.19339756667613983\n",
      "Pretrain Epoch: 030, TSNE Loss: 0.5215\n",
      "local_loss: 1.9789973497390747,global_loss: -0.2412262111902237\n",
      "local_loss: 1.9035731554031372,global_loss: -0.23507224023342133\n",
      "local_loss: 1.7632852792739868,global_loss: -0.17831213772296906\n",
      "local_loss: 1.791121244430542,global_loss: -0.2529278099536896\n",
      "local_loss: 1.6987042427062988,global_loss: -0.24162845313549042\n",
      "Pretrain Epoch: 035, TSNE Loss: 0.4371\n",
      "local_loss: 1.5872076749801636,global_loss: -0.20026597380638123\n",
      "local_loss: 1.4762967824935913,global_loss: -0.1980145126581192\n",
      "local_loss: 1.4389777183532715,global_loss: -0.23124831914901733\n",
      "local_loss: 1.3901888132095337,global_loss: -0.2463245540857315\n",
      "local_loss: 1.275395154953003,global_loss: -0.21348609030246735\n",
      "Pretrain Epoch: 040, TSNE Loss: 0.3186\n",
      "local_loss: 1.282406210899353,global_loss: -0.2768767178058624\n",
      "local_loss: 1.145999789237976,global_loss: -0.22875002026557922\n",
      "local_loss: 1.0466265678405762,global_loss: -0.25507739186286926\n",
      "local_loss: 0.9695751667022705,global_loss: -0.22181905806064606\n",
      "local_loss: 0.8861254453659058,global_loss: -0.2269888073205948\n",
      "Pretrain Epoch: 045, TSNE Loss: 0.1977\n",
      "local_loss: 0.8423073887825012,global_loss: -0.2242603451013565\n",
      "local_loss: 0.7229529619216919,global_loss: -0.23091045022010803\n",
      "local_loss: 0.6409521698951721,global_loss: -0.22572416067123413\n",
      "local_loss: 0.5673139691352844,global_loss: -0.20794638991355896\n",
      "local_loss: 0.45951804518699646,global_loss: -0.18445244431495667\n",
      "Pretrain Epoch: 050, TSNE Loss: 0.0825\n",
      "local_loss: 0.3795933127403259,global_loss: -0.16513192653656006\n",
      "local_loss: 0.3473830819129944,global_loss: -0.17554950714111328\n",
      "local_loss: 0.25485631823539734,global_loss: -0.15679717063903809\n",
      "local_loss: 0.24956771731376648,global_loss: -0.16437318921089172\n",
      "local_loss: 0.16668292880058289,global_loss: -0.13902704417705536\n",
      "Pretrain Epoch: 055, TSNE Loss: 0.0083\n",
      "local_loss: 0.12865926325321198,global_loss: -0.12765102088451385\n",
      "local_loss: 0.10174529999494553,global_loss: -0.13819265365600586\n",
      "local_loss: 0.09746576100587845,global_loss: -0.14815093576908112\n",
      "local_loss: 0.07218343019485474,global_loss: -0.15907441079616547\n",
      "local_loss: 0.0382421612739563,global_loss: -0.1591196060180664\n",
      "Pretrain Epoch: 060, TSNE Loss: -0.0363\n",
      "local_loss: 0.03840913623571396,global_loss: -0.16074435412883759\n",
      "local_loss: 0.0425528958439827,global_loss: -0.1623326689004898\n",
      "local_loss: 0.024059874936938286,global_loss: -0.14895625412464142\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "local_loss: 0.03384610638022423,global_loss: -0.16948774456977844\n",
      "local_loss: 0.0270196832716465,global_loss: -0.16048583388328552\n",
      "Pretrain Epoch: 065, TSNE Loss: -0.0400\n",
      "local_loss: 0.021198159083724022,global_loss: -0.1397792249917984\n",
      "local_loss: 0.02606889232993126,global_loss: -0.14141994714736938\n",
      "local_loss: 0.018914751708507538,global_loss: -0.1312311738729477\n",
      "local_loss: 0.022701064124703407,global_loss: -0.1564519852399826\n",
      "local_loss: 0.025902654975652695,global_loss: -0.15903854370117188\n",
      "Pretrain Epoch: 070, TSNE Loss: -0.0399\n",
      "local_loss: 0.02392546832561493,global_loss: -0.1461203396320343\n",
      "local_loss: 0.02353963814675808,global_loss: -0.1226014718413353\n",
      "local_loss: 0.023370105773210526,global_loss: -0.14301535487174988\n",
      "local_loss: 0.0310473944991827,global_loss: -0.10940343141555786\n",
      "local_loss: 0.030668433755636215,global_loss: -0.13407088816165924\n",
      "Pretrain Epoch: 075, TSNE Loss: -0.0310\n",
      "local_loss: 0.018346941098570824,global_loss: -0.09202869981527328\n",
      "local_loss: 0.02056453749537468,global_loss: -0.11899740248918533\n",
      "local_loss: 0.027921399101614952,global_loss: -0.09705604612827301\n",
      "local_loss: 0.024917062371969223,global_loss: -0.07724564522504807\n",
      "local_loss: 0.01727752946317196,global_loss: -0.08999860286712646\n",
      "Pretrain Epoch: 080, TSNE Loss: -0.0218\n",
      "local_loss: 0.022838914766907692,global_loss: -0.06470238417387009\n",
      "local_loss: 0.02960076369345188,global_loss: -0.04932977631688118\n",
      "local_loss: 0.02968231588602066,global_loss: -0.0631481260061264\n",
      "local_loss: 0.035071518272161484,global_loss: -0.05600220710039139\n",
      "local_loss: 0.033131107687950134,global_loss: 0.002014088910073042\n",
      "Pretrain Epoch: 085, TSNE Loss: 0.0105\n",
      "local_loss: 0.015050601214170456,global_loss: 0.013815701007843018\n",
      "local_loss: 0.03019379824399948,global_loss: 0.004519962705671787\n",
      "local_loss: 0.038194697350263596,global_loss: -0.00037968624383211136\n",
      "local_loss: 0.020125120878219604,global_loss: 0.04487861692905426\n",
      "local_loss: 0.015222149901092052,global_loss: 0.018254157155752182\n",
      "Pretrain Epoch: 090, TSNE Loss: 0.0100\n",
      "local_loss: 0.020603716373443604,global_loss: 0.04985425993800163\n",
      "local_loss: 0.023861916735768318,global_loss: 0.060888517647981644\n",
      "local_loss: 0.015489491634070873,global_loss: 0.08790971338748932\n",
      "local_loss: 0.02861098386347294,global_loss: 0.07991833239793777\n",
      "Pretrain early stopping at epoch 94\n",
      "Average epoch time for BSNE: 0.0513 seconds\n",
      "\n",
      "=== Time Comparison ===\n",
      "SNE average epoch time: 3.4191 seconds\n",
      "BSNE average epoch time: 0.0513 seconds\n",
      "BSNE is 0.02 times slower than SNE\n"
     ]
    }
   ],
   "source": [
    "# 时间开销对比\n",
    "\n",
    "import time\n",
    "\n",
    "def calculate_sne_loss(emb, dist_matrix, temperature=1, eps=1e-12):\n",
    "    \"\"\"\n",
    "    :param emb: 全图嵌入 [num_nodes, d_model]\n",
    "    :param dist_matrix: 全图距离矩阵 [num_nodes, num_nodes]\n",
    "    \"\"\"\n",
    "    device = emb.device\n",
    "    num_nodes = emb.size(0)\n",
    "    \n",
    "    with torch.no_grad():\n",
    "        P = (1.0 + dist_matrix ** 2) ** -1\n",
    "        P.fill_diagonal_(0)\n",
    "        P = (P + P.T) / 2 \n",
    "        P = P / (P.sum(dim=1, keepdim=True) + eps)\n",
    "        P = torch.clamp(P, min=eps)\n",
    "\n",
    "    pairwise_dist = torch.cdist(emb, emb, p=2)\n",
    "    Q = (1.0 + pairwise_dist ** 2 / temperature) ** -1\n",
    "    Q.fill_diagonal_(0)\n",
    "    Q = (Q + Q.T) / 2\n",
    "    Q = Q / (Q.sum(dim=1, keepdim=True) + eps)\n",
    "    Q = torch.clamp(Q, min=eps)\n",
    "\n",
    "    loss = (P * (torch.log(P) - torch.log(Q))).sum()\n",
    "\n",
    "    return loss\n",
    "def calculate_tsne_loss(emb_p, emb_u, dist_sub_p, dist_matrix, batch_p_global, batch_u_global, temperature=1,\n",
    "                        eps=1e-12):\n",
    "\n",
    "    device = emb_p.device\n",
    "    batch_size = emb_p.size(0)\n",
    "\n",
    "    # --- 局部项：基于B_p子图 ---\n",
    "    # 生成子图局部索引映射表\n",
    "    subnode_to_local = {node: i for i, node in enumerate(batch_p_global)}\n",
    "    local_indices = [subnode_to_local[node] for node in batch_p_global]\n",
    "\n",
    "    # 提取局部距离矩阵\n",
    "    dist_p = dist_sub_p[local_indices][:, local_indices]\n",
    "\n",
    "    # 计算P\n",
    "    # P = torch.exp(-dist_p ** 2)\n",
    "    P = (1.0 + dist_p ** 2) ** -1\n",
    "    P.fill_diagonal_(0)\n",
    "    P = (P + P.T) / 2  # 对称化\n",
    "    P = P / (P.sum(dim=1, keepdim=True) + eps)\n",
    "    # P = P / P.sum()\n",
    "    P = torch.clamp(P, min=eps)\n",
    "\n",
    "    # 计算Q\n",
    "    # pairwise_dist = torch.cdist(emb_p, emb_p)\n",
    "    pairwise_dist = torch.cdist(emb_p, emb_p, p=2)\n",
    "    Q = (1.0 + pairwise_dist ** 2 / temperature) ** -1\n",
    "    Q.fill_diagonal_(0)\n",
    "    Q = (Q + Q.T) / 2\n",
    "    Q = Q / (Q.sum(dim=1, keepdim=True) + eps)\n",
    "    # Q = Q / Q.sum()\n",
    "    Q = torch.clamp(Q, min=eps)\n",
    "\n",
    "\n",
    "    # 局部损失：KL散度\n",
    "    loss_local = (torch.log(P) - torch.log(Q)).mean()\n",
    "\n",
    "    # --- 全局项 ---\n",
    "\n",
    "    # 计算emb_p到emb_u的距离（平方欧氏距离）\n",
    "    dist_pu_sq = torch.cdist(emb_p, emb_u, p=2) ** 2\n",
    "    d_bu = (1.0 + dist_pu_sq / temperature) ** -1\n",
    "    d_bu = d_bu.sum(dim=1)\n",
    "\n",
    "    pairwise_dist_sq = pairwise_dist ** 2\n",
    "    d_bp = (1.0 + pairwise_dist_sq / temperature) ** -1\n",
    "    d_bp = d_bp.sum(dim=1) + eps\n",
    "\n",
    "    # 计算k_Bp（保持原逻辑）\n",
    "    p_xi_full = (1.0 + dist_matrix ** 2) ** -1 \n",
    "    sum_p_xi = p_xi_full[batch_p_global][:, batch_p_global].sum(dim=1)\n",
    "    k_Bp = (sum_p_xi / p_xi_full[batch_p_global].sum(dim=1)) * (dist_matrix.shape[0] / batch_size)\n",
    "\n",
    "    ratio = (k_Bp.unsqueeze(1) * d_bu) / d_bp.unsqueeze(1)\n",
    "    loss_global = torch.log(ratio.clamp(min=eps)).mean()\n",
    "\n",
    "#     print(len(k_Bp),f\"k_Bp: {k_Bp}\")\n",
    "#     print(f\"d_bu mean: {d_bu.mean().item()}, d_bu max: {d_bu.max().item()}, d_bu min: {d_bu.min().item()}\")\n",
    "#     print(f\"d_bp mean: {d_bp.mean().item()}, d_bp max: {d_bp.max().item()}, d_bp min: {d_bp.min().item()}\")\n",
    "\n",
    "\n",
    "#     print(f\"P mean: {P.mean().item()}, P max: {P.max().item()}, P min: {P.min().item()}\")\n",
    "#     print(f\"Q mean: {Q.mean().item()}, Q max: {Q.max().item()}, Q min: {Q.min().item()}\")\n",
    "    print(f\"local_loss: {loss_local.item()},global_loss: {loss_global.item()}\")\n",
    "\n",
    "    return loss_local + loss_global\n",
    "# SNE算法预训练时间测量\n",
    "def sne_main_with_timing(args):\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    print(device)\n",
    "\n",
    "    timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n",
    "    writer = SummaryWriter(f'runs/{args[\"dataset\"]}_{timestamp}')\n",
    "\n",
    "    print('loading data...')\n",
    "    prefix = \"../../data/\"\n",
    "\n",
    "    edge_indexs, feat_data, labels = load_data(args['dataset'], args['layers_tree'], prefix)\n",
    "\n",
    "    np.random.seed(args['seed'])\n",
    "    rd.seed(args['seed'])\n",
    "\n",
    "    if args['dataset'] == 'yelp':\n",
    "        index = list(range(len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels, stratify=labels,\n",
    "                                                                        test_size=args['test_size'], random_state=2,\n",
    "                                                                        shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"YelpChi_shortest_distance.pkl\")\n",
    "    elif args['dataset'] == 'amazon':\n",
    "        index = list(range(3305, len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels[3305:],\n",
    "                                                                        stratify=labels[3305:],\n",
    "                                                                        test_size=args['test_size'],\n",
    "                                                                        random_state=2, shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"Amazon_shortest_distance.pkl\")\n",
    "\n",
    "    with open(dist_path, 'rb') as f:\n",
    "        dist_data = pickle.load(f)\n",
    "        dist_matrix = torch.tensor(dist_data['dist_matrix']).to(device)\n",
    "\n",
    "    adj_dict = defaultdict(list)\n",
    "    for rel in edge_indexs:\n",
    "        edge_index = rel[0].cpu().numpy()\n",
    "        for src, dst in zip(edge_index[0], edge_index[1]):\n",
    "            adj_dict[src].append(dst)\n",
    "\n",
    "    sne_model = multi_HOGRL_Transformer(\n",
    "        in_feat=feat_data.shape[1],\n",
    "        out_feat=2,\n",
    "        relation_nums=len(edge_indexs),\n",
    "        d_model=64,\n",
    "        nhead=2,\n",
    "        num_layers=1,\n",
    "        dim_feedforward=128,\n",
    "        drop_rate=args['drop_rate'],\n",
    "        layers_tree=args['layers_tree'],\n",
    "        tsne_weight=args['tsne_weight']\n",
    "    ).to(device)\n",
    "\n",
    "    for edge_index in edge_indexs:\n",
    "        edge_index[0] = edge_index[0].to(device)\n",
    "        edge_index[1] = [tensor.to(device) for tensor in edge_index[1]]\n",
    "    feat_data = torch.tensor(feat_data).float().to(device)\n",
    "\n",
    "    print(\"\\n=== Starting Pretraining ===\")\n",
    "\n",
    "    sne_model.classifier.requires_grad_(False)\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        filter(lambda p: p.requires_grad, sne_model.parameters()),\n",
    "        lr=args['pretrain_lr'],\n",
    "        weight_decay=5e-5\n",
    "    )\n",
    "\n",
    "    best_val_auc = 0.0\n",
    "    best_model_state = None\n",
    "    train_pos, train_neg = pos_neg_split(idx_train, y_train)\n",
    "\n",
    "    epoch_times = []  # 用于存储每个epoch的时间\n",
    "\n",
    "    for epoch in range(args['num_epochs']):\n",
    "        start_time = time.time()  # 记录epoch开始时间\n",
    "\n",
    "        sne_model.train()\n",
    "        loss = 0\n",
    "        sampled_idx_train = undersample(train_pos, train_neg, scale=1)\n",
    "        rd.shuffle(sampled_idx_train)\n",
    "\n",
    "        num_batches = int(len(sampled_idx_train) / args['batch_size']) + 1\n",
    "        for batch in range(num_batches):\n",
    "            i_start = batch * args['batch_size']\n",
    "            i_end = min((batch + 1) * args['batch_size'], len(sampled_idx_train))\n",
    "            batch_nodes = sampled_idx_train[i_start:i_end]\n",
    "            batch_label = torch.tensor(labels[np.array(batch_nodes)]).long().to(device)\n",
    "            optimizer.zero_grad()\n",
    "            out, embedding = sne_model(feat_data, edge_indexs)\n",
    "            batch_nodes_tensor = torch.tensor(batch_nodes, dtype=torch.long, device=device)\n",
    "            tsne_loss = calculate_sne_loss(\n",
    "                embedding,\n",
    "                dist_matrix,\n",
    "                temperature=10,\n",
    "                eps=1e-12\n",
    "            )\n",
    "            loss = tsne_loss * args['tsne_weight']\n",
    "            loss.backward()\n",
    "            optimizer.step()\n",
    "\n",
    "        end_time = time.time()  # 记录epoch结束时间\n",
    "        epoch_times.append(end_time - start_time)  # 保存该epoch的时间\n",
    "\n",
    "        if epoch % 10 == 9:  # validate every 10 epochs\n",
    "            val_auc, val_ap, val_f1, val_g_mean = test(idx_val, y_val, sne_model, feat_data, edge_indexs, device)\n",
    "            writer.add_scalar('Validation/AUC', val_auc, epoch)\n",
    "            writer.add_scalar('Validation/F1', val_f1, epoch)\n",
    "            writer.add_scalar('Validation/GMean', val_g_mean, epoch)\n",
    "            print(f'Epoch: {epoch}, Val AUC: {val_auc:.4f}, Val AP: {val_ap:.4f}, Val F1: {val_f1:.4f}, Val G-mean: {val_g_mean:.4f}')\n",
    "\n",
    "            if val_auc > best_val_auc:\n",
    "                best_val_auc = val_auc\n",
    "                best_model_state = copy.deepcopy(sne_model.state_dict())\n",
    "\n",
    "    avg_epoch_time = sum(epoch_times) / len(epoch_times)\n",
    "    print(f\"Average epoch time for SNE: {avg_epoch_time:.4f} seconds\")\n",
    "    return avg_epoch_time\n",
    "\n",
    "\n",
    "# BSNE算法预训练时间测量\n",
    "def bsne_main_with_timing(args):\n",
    "    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n",
    "    print(device)\n",
    "\n",
    "    timestamp = time.strftime(\"%Y%m%d-%H%M%S\")\n",
    "    writer = SummaryWriter(f'runs/{args[\"dataset\"]}_{timestamp}')\n",
    "\n",
    "    print('loading data...')\n",
    "    prefix = \"../../data/\"\n",
    "\n",
    "    edge_indexs, feat_data, labels = load_data(args['dataset'], args['layers_tree'], prefix)\n",
    "\n",
    "    np.random.seed(args['seed'])\n",
    "    rd.seed(args['seed'])\n",
    "\n",
    "    if args['dataset'] == 'yelp':\n",
    "        index = list(range(len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels, stratify=labels,\n",
    "                                                                        test_size=args['test_size'], random_state=2,\n",
    "                                                                        shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"YelpChi_shortest_distance.pkl\")\n",
    "    elif args['dataset'] == 'amazon':\n",
    "        index = list(range(3305, len(labels)))\n",
    "        idx_train_val, idx_test, y_train_val, y_test = train_test_split(index, labels[3305:],\n",
    "                                                                        stratify=labels[3305:],\n",
    "                                                                        test_size=args['test_size'],\n",
    "                                                                        random_state=2, shuffle=True)\n",
    "        idx_train, idx_val, y_train, y_val = train_test_split(idx_train_val, y_train_val,\n",
    "                                                              stratify=y_train_val, test_size=args['val_size'],\n",
    "                                                              random_state=2, shuffle=True)\n",
    "        dist_path = os.path.join(prefix, \"Amazon_shortest_distance.pkl\")\n",
    "\n",
    "    with open(dist_path, 'rb') as f:\n",
    "        dist_data = pickle.load(f)\n",
    "        dist_matrix = torch.tensor(dist_data['dist_matrix']).to(device)\n",
    "\n",
    "    adj_dict = defaultdict(list)\n",
    "    for rel in edge_indexs:\n",
    "        edge_index = rel[0].cpu().numpy()\n",
    "        for src, dst in zip(edge_index[0], edge_index[1]):\n",
    "            adj_dict[src].append(dst)\n",
    "\n",
    "    gnn_model = multi_HOGRL_Transformer(\n",
    "        in_feat=feat_data.shape[1],\n",
    "        out_feat=2,\n",
    "        relation_nums=len(edge_indexs),\n",
    "        d_model=64,\n",
    "        nhead=2,\n",
    "        num_layers=1,\n",
    "        dim_feedforward=128,\n",
    "        drop_rate=args['drop_rate'],\n",
    "        layers_tree=args['layers_tree'],\n",
    "        tsne_weight=args['tsne_weight']\n",
    "    ).to(device)\n",
    "\n",
    "    for edge_index in edge_indexs:\n",
    "        edge_index[0] = edge_index[0].to(device)\n",
    "        edge_index[1] = [tensor.to(device) for tensor in edge_index[1]]\n",
    "    feat_data = torch.tensor(feat_data).float().to(device)\n",
    "\n",
    "    print(\"\\n=== Starting Pretraining ===\")\n",
    "\n",
    "    gnn_model.classifier.requires_grad_(False)\n",
    "    optimizer = torch.optim.AdamW(\n",
    "        filter(lambda p: p.requires_grad, gnn_model.parameters()),\n",
    "        lr=args['pretrain_lr'],\n",
    "        weight_decay=5e-5\n",
    "    )\n",
    "    pretrain_best_loss = float('inf')\n",
    "    pretrain_no_improve = 0\n",
    "    pretrain_early_stop = False\n",
    "\n",
    "    epoch_times = []  # 用于存储每个epoch的时间\n",
    "\n",
    "    for epoch in range(args['pretrain_epochs']):\n",
    "        start_time = time.time()  # 记录epoch开始时间\n",
    "\n",
    "        if pretrain_early_stop:\n",
    "            break\n",
    "\n",
    "        gnn_model.train()\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        # 第一次采样\n",
    "        batch_centers = rd.sample(range(feat_data.shape[0]), args['batch_size'])\n",
    "        sub_nodes_p = sample_subgraph(batch_centers, dist_matrix, args['sample_size'])\n",
    "\n",
    "        # 第二次采样\n",
    "        batch_u_global = np.random.choice(feat_data.shape[0], size=len(sub_nodes_p), replace=False)\n",
    "\n",
    "        # 生成B_p嵌入\n",
    "        feat_sub_p = feat_data[sub_nodes_p]\n",
    "        _, embeddings_p = gnn_model(feat_sub_p, edge_indexs, sub_nodes=None)\n",
    "\n",
    "        # 生成B_u嵌入\n",
    "        feat_u = feat_data[batch_u_global]\n",
    "        with torch.no_grad():\n",
    "            _, embeddings_u = gnn_model(feat_u, edge_indexs, sub_nodes=None)\n",
    "\n",
    "        # 获取B_p子图距离矩阵\n",
    "        dist_sub_p = dist_matrix[sub_nodes_p][:, sub_nodes_p]\n",
    "\n",
    "        # 计算损失\n",
    "        tsne_loss = calculate_tsne_loss(\n",
    "            embeddings_p,  # 子图嵌入\n",
    "            embeddings_u,  # 全局采样嵌入\n",
    "            dist_sub_p,  # B_p子图距离\n",
    "            dist_matrix,  # 全图距离\n",
    "            sub_nodes_p,  # B_p全局索引\n",
    "            batch_u_global,  # B_u全局索引\n",
    "            temperature=100,\n",
    "            eps=1e-10\n",
    "        ) * args['tsne_weight']\n",
    "\n",
    "        tsne_loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        end_time = time.time()  # 记录epoch结束时间\n",
    "        epoch_times.append(end_time - start_time)  # 保存该epoch的时间\n",
    "\n",
    "        if tsne_loss.item() < pretrain_best_loss:\n",
    "            pretrain_best_loss = tsne_loss.item()\n",
    "            pretrain_no_improve = 0\n",
    "        else:\n",
    "            pretrain_no_improve += 1\n",
    "\n",
    "        if pretrain_no_improve >= args['pretrain_patience']:\n",
    "            print(f\"Pretrain early stopping at epoch {epoch}\")\n",
    "            pretrain_early_stop = True\n",
    "\n",
    "        writer.add_scalar('Pretrain/TSNE_Loss', tsne_loss.item(), epoch)\n",
    "\n",
    "        if epoch % 5 == 0:\n",
    "            print(f'Pretrain Epoch: {epoch:03d}, TSNE Loss: {tsne_loss.item():.4f}')\n",
    "\n",
    "    avg_epoch_time = sum(epoch_times) / len(epoch_times)\n",
    "    print(f\"Average epoch time for BSNE: {avg_epoch_time:.4f} seconds\")\n",
    "    return avg_epoch_time\n",
    "\n",
    "\n",
    "\n",
    "def main():\n",
    "    args = {\n",
    "        \"dataset\": \"amazon\",\n",
    "        \"batch_size\": 32,\n",
    "        \"weight_decay\": 0.00005,\n",
    "        \"pretrain_epochs\": 200,\n",
    "        \"pretrain_lr\": 0.001,\n",
    "        \"finetune_lr\": 0.0005,\n",
    "        \"num_epochs\": 500,\n",
    "        \"pretrain_patience\": 30,\n",
    "        \"tsne_weight\": 0.3,\n",
    "        \"weight\": 0.6,\n",
    "        \"layers\": 4,\n",
    "        \"test_size\": 0.6,\n",
    "        \"val_size\": 0.5,\n",
    "        \"layers_tree\": 1,\n",
    "        \"seed\": 76,\n",
    "        \"num_heads\": 2,\n",
    "        \"drop_rate\": 0.5,\n",
    "        \"sample_size\": 50  \n",
    "    }\n",
    "\n",
    "    print(\"Running SNE pretraining...\")\n",
    "    sne_avg_time = sne_main_with_timing(args.copy())\n",
    "\n",
    "    print(\"Running BSNE pretraining...\")\n",
    "    bsne_avg_time = bsne_main_with_timing(args.copy())\n",
    "\n",
    "    print(\"\\n=== Time Comparison ===\")\n",
    "    print(f\"SNE average epoch time: {sne_avg_time:.4f} seconds\")\n",
    "    print(f\"BSNE average epoch time: {bsne_avg_time:.4f} seconds\")\n",
    "    print(f\"BSNE is {(bsne_avg_time / sne_avg_time):.2f} times slower than SNE\")\n",
    "\n",
    "\n",
    "main()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "id": "f4ea5a7f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAeYAAAHqCAYAAADRUe20AAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjEsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvc2/+5QAAAAlwSFlzAAAPYQAAD2EBqD+naQAAbuJJREFUeJzt3XlcVNX/P/DXnRk2UUBBRAQRcAFE3HEpBffccsu90rS0tEzNcsu1RUv6ZlZqfnItNbdSczf3fU9zwVAxF1xRBmWTmTm/P/hx5TIzAgrMFV/Px4PHg3nfc+99nzvLe+69Z+6VhBACREREpAoaWydAREREj7EwExERqQgLMxERkYqwMBMREakICzMREZGKsDATERGpCAszERGRirAwExERqQgLMxERkYqwMBOpwOXLlyFJkvy3c+fOZ15mZGSkvLy+ffs+8/LUYsGCBYptRS+2rK+FBQsW2DqdfJHnwiyEwJIlS9CyZUt4enrCzs4Obm5uCAgIQNOmTfHRRx9h9+7dinl27typ2Hg6nQ7nzp1TtHn48KGizcSJExXTs06z9lehQoU8b4DClv0DOPNPq9WiRIkSCA4ORr9+/XD06FGL8xsMBsycORONGzdGqVKlYGdnh1KlSqFSpUp45ZVXMHr0aPz999+KebJ/kLm4uODu3buKNqdPn7b6AreWc/a/yMjIZ9oGkiShWLFiqFixIvr164eTJ0/mZdM+swoVKlh9/VH+yv6ZUFTe33lx8eJFjBw5EvXq1YOHhwfs7e1Rrlw51K9fH2PGjME///xj6xTJRnR5neGNN97A4sWLFTG9Xg+9Xo/Y2Fjs2LEDer0ejRs3troMo9GITz/9FKtWrcp7xkWUyWTCw4cPER0djejoaCxatAi///47Xn31VblNeno6WrZsabY3df/+fdy/fx8XLlzA5s2b4ezsjBo1alhd14MHDzBlyhR88803BdSbp5eSkoKLFy/i4sWL+OWXXzB37ly8+eabtk6rwJUqVQrTpk2THwcGBj7zMt977z20a9cOABAaGvrMy1OLunXrKrbV88ZoNGLChAmYMmUKTCaTYlpcXBzi4uJw6NAhzJw5EwkJCbZJ8jmS9bVQt25dG2aSf/JUmDdu3KgoyvXq1UPz5s3h4OCAq1ev4vz58zhw4ECulvX777/jyJEjT7Uh69Spg+7du5vFXV1d87wsW2vRogVatmwJk8mEs2fPYtGiRRBCwGg0Yvz48YrC/PPPPyuKcosWLdCgQQNoNBpcuXIFZ86cwZEjR3K13pkzZ2LYsGHw8fF56pyz8/X1zfOysi7PaDTi5MmTWLZsGUwmEwwGA9599120aNECZcuWzXE5Dx48QIkSJZ4qh/zwLOt3cXHBiBEj8jUfS+8RtQgMDDQrrlu2bMHWrVvlx2PGjEHJkiXlx5nv76pVq6Jq1aqFk2gBGDRoEObMmSM/dnJyQqdOnRAcHAyDwYAzZ85g06ZNNsxQ/QwGA9LT0+Hk5JTv7xtVEHkwbNgwAUAAEJUqVRJGo9Gszb1798ShQ4cUsR07dsjzZf1r3ry53ObBgweKaRMmTFAsI+u0Pn365CVtha1bt8rLkSRJXLlyRTE9PT1duLu7y22++eYbOf7tt9+K+vXrC1dXV6HVakWpUqVESEiIeOONN8TSpUtztf7Y2Ngn9rNdu3byNAcHB8W0Tp06ydOaNm1qcflxcXHi5MmTitj8+fMtbv+3335bbvPPP/8ops2fPz/XOedVTssbO3asYvrcuXMtzrd9+3bx448/itDQUOHg4CAiIiIUy/njjz9Eu3bthJeXl7CzsxMlS5YUzZs3F6tWrVK069Onj8Xtk/Uvk5+fnyLvv/76SzRu3FiUKFFCbpeSkiLGjBkjWrVqJfz9/YWLi4vQ6XTC3d1dNGrUSHz//fciPT39idtkx44d8rQJEybIcT8/P3H//n0xdOhQ4ePjI+zt7UXlypXFzJkzzbZzRESExfeM
pXX98ssvonbt2sLR0VG4u7uLPn36iPj4eLNlJiUliVGjRglfX1/h4OAggoODxffffy8uXbpkNf/cytpPACI2NtZiu+yv5yf1+dChQ6JZs2bC2dlZeHp6ikGDBokHDx4IIYRYsWKFqFWrlnB0dBTe3t5i+PDhIjU11eI6c/taysnGjRsVuVeuXNliPx88eCCioqLM4ocPHxavv/668PPzE/b29qJ48eIiLCxMjB49Wty+fdusffbX64YNG0T9+vWFk5OTKFeunBg7dqx49OiREEKIWbNmieDgYOHg4CD8/f3FF198IUwmk2J5Wd8rERER4vr166JPnz7C09NTODo6itq1a4tly5aZ5bF161bx1ltviRo1aogyZcoIe3t74eTkJCpWrCjeeustcerUKbN5sq/r4sWLolu3bsLd3V1IkiS/xqx9bgmR8VqJiIgQ7u7uQqfTCTc3N1G5cmXRrVs38eOPP5qtMz4+XkyYMEHUrFlTlChRQtjb2wsfHx/RvXt3sXfvXrP2T/vezEmeCvMHH3wgJ+Hu7i7Onz+fq/myF2YvLy/5/23btgkhCq8wm0wmxYv166+/VkzfsGGDPE2n04lbt24JIXL+8K5Xr16u1m+tKBmNRnH27FlRvnx5xROdVfv27RVv6Li4uFytM/sHWeb212q18nOopsK8bt06xfQvvvjC4nwvvfSS4nFmYTYajaJXr15PfL4GDBggr+9pC3P9+vWFVqs1a3fnzp0cl9e8eXNhMBisbhNrhdnd3V0EBQVZXOacOXMU2zG3hTn7dswaz+rRo0eiUaNGFttmfW1mzz+38rswV61aVTg4OJjlGhkZKb755huL/XjjjTcUy8vrayknrVq1Usx79OjRXM/77bffCo1GYzWPMmXKiOPHjyvmyfp6rVmzppAkyWy+Pn36iA8//NDiMseNG6dYXtb3SuXKlUW5cuUszjd9+nTFfIMHD37iNrS3txdbt261uq5KlSoJT09Pi6+xrLGsn1vZX0+WtldWZ86cET4+PlbbS5IkfxZZWkde3ps5ydOh7KznLePj4xEUFISwsDDUrVsXdevWRfPmzREQEJDjckaNGoWPP/4Y6enpGDNmDA4ePJiXNHDmzBlERUWZxRs2bIiGDRs+cV5JktCnTx9MnjwZALBkyRJ8/PHH8vSlS5fK/7dt2xaenp54+PAhfv31VznepUsX1KpVC3q9Hv/99x927dqVp/yzmjRpEiZNmmRx2siRIxWPa9SogT///BMA8O+//6J8+fKoVasW6tSpI59W8Pb2znGd48aNw+DBg2E0GjFu3DgsW7YsTznv37/f4vZv3bp1vhxizH46xMvLy2K7ffv2ISAgAJ07d4ajoyOSk5MBAFOnTsWSJUsAABqNBl27dkVoaChiYmKwePFiGI1GzJkzB7Vr18aAAQPQo0cPhIaG4ssvv8T9+/cBWD9cn9XBgwdRokQJ9O7dG97e3vKAPUmSULFiRdSrVw/e3t4oWbIk0tPTER0djRUrVsBgMOCvv/7CqlWr0K1btzxtm/j4eCQkJKBfv35wd3fHjz/+KPc7KioK77zzTp6WB2RsxwYNGqBZs2ZYt26dPHhw3759OHDgABo0aAAA+O6777Bnzx55vrCwMHTo0AEnT57E2rVr87zegnbmzBn4+fmhd+/eOHToELZt2wYgY+DZzp07Ua1aNXTs2BF//vmn3OfFixdj6tSp8vsor6+lJzGZTIrPiurVq6N27dq56suuXbswfPhwZNQhwN/fHz169MC9e/cwf/58PHr0CLdu3UKnTp1w/vx5ODg4mC3jxIkTqFq1Kjp37oyNGzfKr9eFCxcCAF566SU0bdoUixcvxqVLlwBkPOeffvop7O3tzZb377//wtXVFcOGDYMkSZg3b558TvyTTz5B+/bt5XpQvHhxNGnSBFWrVkWpUqXg5OSE+Ph4rF+/HufOncOjR48wZMgQnD171mL/Y2JiIEkSunbtimrVquHy5ctwdnZ+4jab
NWuW/H+zZs3QpEkTJCUl4erVq9i7dy9SUlLk6QaDAZ06dcK1a9cAADqdDn369EGZMmWwYsUKxMTEQAiBsWPHombNmmjdurXZ+vL1vZmXKv7o0SNRvXr1J34LadKkiYiOjlbMl32P+c8//xQDBw6UH//xxx952mO29pfbPblLly4pvjmeO3dOCJFxCDLzkCQAsWbNGiFExuH5zJiLi4tIS0tTLM9kMolLly7lat3Z91as/Q0cONDsMNK9e/esfkMFIDQajejcubPZnnT2PYx//vlH/uYuSZI4fvx4nvaYrf1lP4yU223QokULMW3aNDF16lTRq1cvxV6Bk5OT3J/s81WqVEno9XrFso1Go+JUxJdffqmYPmrUKMX8WWU/7GdJ1jY6nc7iIbhMt27dEmvWrBEzZ84UUVFRYtq0aSI0NFSev1+/fla3ibU9ZgDihx9+kKdNnz5dMS0xMVGelts95vr168uH1uPj4xVHAWbMmCHPV7lyZTleoUIFkZycLE/LftRBDXvMOp1OXsbDhw8V/SpdurS8rc6cOaNY3tq1a4UQz/ZasuT27duK9XTv3j3X26ZDhw7yfCVKlBB37tyRpy1atEix3F9//VWelvX16u7uLr9foqOjFfOEhobKh7TXr1+vmJb1NZ79ed63b588bd++fYpp2fe2jUajOHTokFiwYIGYPn26mDZtmhg+fLhinqynFrOvy9ohYWufQS4uLnL8xo0bZvNdvHhR/v+PP/5QLOenn36Sp92/f1+UKlVKnpb1FOzTvjdzkqfCLIQQCQkJYsSIEcLDw8PqB3T58uUVSVgqzNevXxdOTk4CyDjkpNfrFW0KsjALIURkZKQ83/jx44UQGeecMmNlypRRnAesWrWqPM3b21t06NBBjBgxQixcuFBcu3Yt1+u1VpS++uorMXDgQHmbABBvvfWW2fzXr18X77zzjuILRPa/OnXqKA6TWirMx44dk7+cvPLKKzYtzNb+tFqtmDdvntX5vv32W7Nlnz17NlfLzvzL+gGX18LcoUMHi22Sk5NF3759n3jYEYBo2bKl1b5ZK8xarVakpKTI07Kfs/zvv//kabktzP/73/8U+ZcpU0aeNmnSJCGEEImJiYp5Pv74Y8U8O3futJp/buV3Yc4+5iDrKbS+ffvK8fT0dMXyFi5cKIR4tteSJbdu3VK0z0thLl26tDxft27dFNMMBoOws7OTpw8aNEielvX1mvU1kJqaqshl4sSJ8rSYmBjFtF27dsnTshbLgIAAszz9/f3l6a1bt5bjW7ZsUZyms/a3f/9+i+sqVaqU4jMtK2ufQW3btpXj7u7uok2bNuLDDz8Uc+bMETExMYplfPzxx4rlJCUlKaa/9dZb8rRixYrJ8ad9b+Ykz79jdnV1xbRp03Dr1i2cOnUKc+bMQc+ePeHk5CS3uXLlCn7//fcnLsfb2xvvv/8+gIxDTtl/gvUkffr0gcj4UqH4y8tvT9966y35/8zD11kPY7/xxhvQ6R4f6V+yZAlCQkIAZPykYc2aNYiKikKfPn1Qvnx5DB8+PNfrzqphw4YYMWIEPvnkE8yePRszZ86Up82fPx+HDh1StPf29sacOXNw7949HD58GN9//z06duyoyPXo0aPYu3fvE9dbq1YtvPbaawCATZs2KQ5R5mTChAkWt39+XMTCwcEBAQEB6NOnD44cOaJ4nrKrXLmyWezevXt5Wt+dO3fynOOT1g8Ao0ePxoIFC8x+CpNdWlpantdZpkwZODo6yo+zH7LMaZ2W+Pn5KR5nXWbm8vR6vaJN9tML1k432FK5cuUUj7P2K+u0rO8d4HGf8/u15OHhoXjuoqOjc73szFMsAODp6amYptVq4e7ubrFtVln7nP11k5vtkV32PICM12f2POLi4tCxY0dcuXLF4nKysvaeCAwMhFarzXH+rGbNmoX69esDyDjMvGHDBnz33XcYMGAAKlWqhO7du8t9y7rNihcvjmLFilntV3JyMh49emS2vvx8
b+b5d8yZNBoNqlWrhmrVquGdd97BiRMnUKtWLXn6hQsXclzGqFGjMGfOHOj1enz22WdPm8pTee211/D+++/jwYMHiImJwfbt27FhwwZ5evaCEBYWhjNnzuCff/7B8ePHERMTg+PHj2Pjxo0wmUz49ttv8eqrr+b6IhvWhIeHKx4fOHAA9erVM2un0+nkc/vvv/8+/vjjD3Tu3FmefuHCBURERDxxXZ999hl+//13GI3GQt/+WU2YMOGpLuiR/c0DQPHzGgB4++23UaVKFavLsPTh8izrB6A4Z9+kSRPMmTMH/v7+0Gq16NatG1asWPHU67Szs1M8zo8rX+VmmW5uborHt2/fVjy+efPmM+eR37L3K6vsxceS/H4taTQaREREYPPmzQCAkydP4sSJE6hZs2aucsks/Nm3vdFoRHx8vNW8Mz3r9sguex4AcOvWLfn/zNfMn3/+KZ9rlSQJv/76K9q3b48SJUrg7NmzuRqXYu299iS+vr44cOAALly4gMOHDyMmJganTp3C2rVrYTAYsHz5crRu3Rp9+/ZVbLOHDx8iOTlZsc6s/SpWrJjFc+75+d7M07OxcOFCpKamolevXma/1yxevLjicfY3siWlSpXCiBEjMG7cONy4cSMvqTyzYsWKoVu3bpg7dy6AjDddamoqgIzfZ2fuHWf6+++/UaNGDfnLSKbq1avj1KlTAIBjx449c2HO/jtko9Eo//9///d/KFu2LDp16qT4ZgY83favUqUK+vbti7lz5xb69i8oQUFBcHd3lz+o0tLSLP7O8cqVKzh37pxiTyPrGyvzg+RpZP2QbNeuHSpWrAgg44Nsx44dT71cWypevDiCgoLkvbzff/8dkydPlj+g5s+fb8v0CsSzvJasGTJkiFyYAaBXr17YvHkzypcvr2j38OFDzJkzRz4S17BhQ6xZswZAxhGuu3fvwsPDA0DG0bz09HR53pwGwOaXS5cuYf/+/fL69u/fj9jYWHl6nTp1ACjfD66urujRowc0moyDtb/99luB5Xfy5ElUq1YNFStWlN+DANChQwd5sOKxY8fQt29fs23266+/yoP5EhIS5G0PFM72zVNhjo2NxaRJkzB06FA0atQINWrUQMmSJXH79m3FXoIkSTmOaM00dOhQfP/99xa/fVljbVQ2AHzwwQcWRyRa8tZbb8mFOesLytLh0/r168Pb2xuNGjWCt7c3XFxccPLkSbkoA7krhtlljnAWQuDSpUtYtGiRYnrWF8GpU6fw0UcfoUSJEoiIiEC1atVQokQJXL16VXEY3tHRMce95UwTJkzAr7/+mqfDqtZGZTs6OsqnJ2xFo9Fg6NChGDduHADgl19+QUxMDJo2bQpnZ2fExcXh4MGDOH78ON588020atVKnrdcuXLykZ4FCxbA0dERLi4uCAwMRKdOnXKdQ5UqVXD69GkAwOeff45bt25BkiT88ssvZpdCfZ688847+OijjwBkjJJt2LAh2rZti5MnTyo+uIqKZ3ktWdOmTRv0799f/tyJjo5GcHCwfIGR9PR0nD17Fps2bYJGo5EL89ChQ+VtnJiYiPDwcPTo0QP379/HvHnz5OX7+vqiS5cu+b0pntiffv36yaOyM9nZ2cmntrIeZUhISEDr1q3RqFEjHDt2DKtXry6w3Lp37w69Xo8mTZqgXLlyKFWqFC5evKg4Mpr5md2uXTtUqlQJMTExAIDBgwfj8OHD8PLywvLlyxWnNYYNG1ZgOctyfTZa5Py7sMy/kSNHKuazNPgrq++++85sGU8z+AuAuH//fl66pBhpCmSMAk5ISDBrZ+n3kFn//P39Lc6XXW4HPgHmg79y83tbSZLMfjhvafBXVlkvHJP59zSDv1xdXXO1zZ/2d9FPGiCVlcFgED179swx3+y/h7f0OgQg2rZtK7fJzQCxpUuXWlxO2bJlRYsWLeTHWQcn5eUCI1llf29lHTCVlwuMZGWtj0/6HXPr1q0Vj7MOGMqtgrjASG76JYT1AURP
+1p6kvT0dPHJJ5/kODgw+/spKirqifOULl3a7HfRT9PnJ70+sn4GhYSEiAoVKljMJfPCTEJkvG6qVatmdbvlZl3ZB/Llph9VqlR54vYtVaqU4jX2zz//CG9v7yfOkzkYMtPTvjdzkqfBX0OHDsXKlSsxaNAghIeHo3z58nBycoK9vT18fX3RuXNnrF+/HlOnTs3LYvHuu++aDUApLNkHLHXu3NnipT1nzZqFt956C2FhYShdujR0Oh2KFy+OsLAwfPLJJzh06NAzXxLUzs4O3t7eaNu2LX777Tf5W3Wmr776Cr/++iveeust1KpVC+XKlYO9vT0cHR0REBCA3r17Y9++fRg0aFCe1jtmzBibXsoyv2m1WixZsgRr1qxBhw4d4O3tDTs7O5QsWRKhoaHo3r07Fi9ejO+++04x3+DBgzFx4kQEBAQ81Tm3TD169MDy5ctRvXp12NnZwd3dHd27d8fBgwdz9TtztbKzs8PGjRsxcuRI+Pj4wN7eHlWqVMG3336LTz/9VNH2aY4eqdHTvpaeRKfT4auvvkJ0dDRGjBiBOnXqyDekKVu2LMLDwzFq1CizmwF99NFH2L9/P3r16gVfX1/Y29ujWLFiqFatGkaOHIl//vkn17+Lzg+lS5fGwYMH0a9fP3h6esLBwQE1a9bE0qVLFYNh7ezssH37dvTt2xfu7u5wcHBAaGgo5syZU6A3i5kyZQreffdd1K5dG15eXrCzs0OxYsUQFBSEQYMG4dixY4obo4SGhuLUqVMYN24catSoAWdnZ9jZ2aFcuXLo2rUrdu/ejfHjxxdYvllJ//8bBxFRjlJSUhS/wMg0YsQI+aYoxYsXR3x8vMUBMvR869u3r3xBkoiIiHy5PSmZe/rdAiJ64TRp0gQBAQFo1KgRfH19cf/+fWzatEkxxmHgwIEsykTPgIWZiHItNTUVS5cuVRTirNq2bYsvvviikLMiKlryfIERUodZs2YhLCwMLi4ucHFxQYMGDbBx48Zczbtv3z7odDqzezafOXMGXbp0QYUKFSBJEqZPn24274MHDzB06FD4+fnByckJDRs2NPuJ1++//45WrVrBw8MDkiTJ1yGm59/777+PVq1aoVy5cnB0dISDgwN8fHzQsWNHrFy5EuvWrcv1ryKIyDKeY35O/fnnn9BqtfLv8xYuXIhp06bJF6q3Rq/Xo1atWqhYsSJu3bqlKJpHjhzB8uXLUbt2bQwbNgwjR47E0KFDFfN3794dp0+fxqxZs+Dt7Y1ff/0V3377Lc6ePStfPeiXX35BbGwsvL295YvPZP8SQERElrEwFyGlSpXCtGnT0L9/f6ttevTogUqVKkGr1WL16tVW92YrVKiAoUOHKgpzSkoKSpQogTVr1qBt27ZyvEaNGmjXrh0+//xzxTIuX74Mf39/FmYiojzgoewiwGg04rfffkNSUpJ8iz5L5s+fj4sXL2LChAlPtR6DwQCj0Wh21TEnJ6ccr81NRES5w8FfBcBkMiEuLg4lSpTIl2sZW3PmzBm0aNECqampKF68OBYvXgwfHx8kJiaatb148SJGjhyJTZs2ITk5GWlpaTCZTBbbAoAQAqmpqWbTw8PDMXHiRPj4+MDT0xMrV67EoUOHEBgYaNb2wYMHADIuL2htPUT0/BJC4MGDB/D29pYvs0nPjoeyC8C1a9fg6+tr6zSIiArF1atX4ePjY+s0igzuMReAzKtoXb16FS4uLoW23ldffRX+/v5mVyFKSEiAn5+f4rZpJpMJQghotVr88ccfZtfWrlatGt577z2rVxFLSkrCgwcP4OXlhb59+yIpKcnsrkn//fcfwsLCsGfPHoSFheVTL4lILRITE+Hr61ukrhyoBizMBSDz8HXmT5kKi1arhRDCbJ3FixfHP//8o4jNnDkT27dvx8qVK+Hv7w9nZ2fFdEmS5Js4WOLi4oKyZcvi/v372L59O77++muztplv1uLFixfqdiCiwlWQp+xeRCzMz6kxY8agdevW8PX1
xYMHD/Dbb79h586d2LRpEwBg9OjRuH79OhYtWgSNRoPQ0FDF/J6ennB0dFTEHz16hLNnz8r/X79+HX///TeKFy8u/yxr8+bNEEKgSpUquHDhAj7++GNUqVJFcUeue/fu4cqVK4iLiwMAnD9/HgDg5eUFLy+vgtsoRERFQJE8W79582ZERESgdOnScHBwQEBAAIYPHw69Xv/E+SIjIyFJktlf5j1o1eTWrVt44403UKVKFTRr1gyHDh3Cpk2b0KJFCwDAjRs3cOXKlTwtMy4uDjVr1kTNmjVx48YNREVFoWbNmnj77bflNnq9HoMHD0ZQUBDefPNNvPzyy9iyZYviXsZr165FzZo15Z9U9ejRAzVr1sTs2bPzoedEREVbkRz8tXTpUpw6dQrh4eEoWbIkTp8+jYkTJ6JWrVrYsmWL1fkiIyNhMBjM7jVco0YNs58IPUliYiJcXV2h1+t5CJeIiix+1hWMInkou2fPnujZs6f8ODIyEg4ODhgwYADi4uKeeOs9Nzc31K9fvzDSJCIiMlMkD2Vb4u7uDgBIT0+3cSZERETWFenCbDQakZqaiuPHj2Py5Mlo3749/Pz8njjPrl274OzsDEdHR0RERJjdrJyIiKggFclD2Zn8/Pxw/fp1AMArr7xi9VZ1mSIiIvDmm2+iUqVKiIuLQ1RUFJo3b45du3Y98VKXaWlpSEtLkx9nXuXKYDDAYDAAADQaDTQaDUwmE0wmk9w2M240GpH1dL+1uFarhSRJ8nKzxoGMLyO5iet0OgghFHFJkqDVas1ytBZnn9gn9unF7lP2dVD+KJKDvzKdOnUKDx8+xJkzZ/DZZ5+hYsWK2Lp1q+JCG0+SlJSEqlWrIiQkBBs2bLDabuLEiZg0aZJZ/K+//pJ/H1y6dGkEBgbi4sWLuHPnjtzGx8cHPj4+OHfunGLUeEBAADw9PXHy5EmkpKTI8aCgILi5ueHIkSOKN0xYWBjs7e1x9OhRRQ516tTBo0ePcOrUKTmm1WpRt25dJCQkKEacOzk5oXr16rh9+zYuXbokx11dXREcHIxr167h2rVrcpx9Yp/Ypxe7T0lJSWjevDkHf+WzIl2Yszp27Bjq1KmDFStW4LXXXsv1fIMHD8bKlStx69Ytq20s7TH7+voiPj5efrE+j9+Gc4qzT+wT+/Ri9ykxMRHu7u4szPmsSB/KzqpGjRrQarW4cOFCnubLzfcWBwcHizeH1+l00OmUmzjzzZSdtb14a/Hsy32auCRJFuPWcsxrnH1in6zF2aei0Sdry6JnU6QHf2V14MABGI1GBAQE5HqepKQkrF+/HnXr1i3AzIiIiB4rkl93OnfujDp16iAsLAxOTk44efIkvv76a4SFhaFjx44AgP79+2PhwoXy4Z49e/YgKioKnTp1gp+fH+Li4vDNN9/g5s2bZjdnICIiKihFsjCHh4dj2bJlmDp1KkwmEypUqIABAwZgxIgRsLe3B5BxPiXruZOyZcsiLS0No0ePRnx8PJydndGwYUPMnj0b4eHhtuoKERG9YF6YwV+FiZepI6IXAT/rCsYLc46ZiIjoecDCTEREpCJF8hzz86zCqPW2TkEVLk9ta+sUiIhsgnvMREREKsLCTEREpCIszERERCrCwkxERKQiLMxEREQqwsJMRESkIizMREREKsLCTEREpCIszERERCrCwkxERKQiLMxEREQqwsJMRESkIizMREREKsLCTEREpCIszERERCrCwkxERKQiLMxEREQqwsJMRESkIizMREREKsLCTEREpCIszERERCrCwkxERKQiLMxEREQqwsJMRESkIizMREREKsLCTEREpCIszERERCrCwkxERKQiLMxEREQqwsJMRESkIizMREREKsLCTEREpCIszERERCrCwkxERKQiLMxEREQqwsJMRESkIizMREREKsLCTEREpCIszERERCrCwkxERKQiRbIwb968GREREShdujQcHBwQEBCA4cOHQ6/X5zjvwoULERQUBEdH
R4SGhmLFihWFkDEREVGGIlmY7927h4YNG2LOnDnYvHkzhg8fjkWLFqFr165PnG/lypXo27cvOnXqhI0bN6JZs2bo3r07tmzZUkiZExHRi04SQghbJ1EY/ve//2HAgAG4fv06vL29LbYJDg5GtWrVsHz5cjnWqlUr6PV6HDx4MNfrSkxMhKurK/R6PVxcXPKUZ4VR6/PUvqi6PLWtrVMgohw8y2cdWVck95gtcXd3BwCkp6dbnB4bG4vo6Gj07NlTEe/VqxcOHz6Mu3fvFniORERERbowG41GpKam4vjx45g8eTLat28PPz8/i23PnTsHIGOvOauQkBAIIRAdHV3g+RIREelsnUBB8vPzw/Xr1wEAr7zyCpYuXWq17f379wEAbm5uinjJkiUBZJy3tiYtLQ1paWny48TERACAwWCAwWAAAGg0Gmg0GphMJphMJrltZtxoNEIIATtNxpkFowkwQYJOEpCkx+symAABSW6njAN22b5qpZsACYDOLC5BglDEhQAMQoIGAlpLcUlAmyUXkwCMQoJWEtBkiRsFYBLmueelTwaDAVqtNmM+o1GRu7W4TqeDEEIRlyQJWq3WbLtbi+f2ecoprtVqIUmS/PznlDv7xD49j33Kvg7KH0W6MG/YsAEPHz7EmTNn8Nlnn6F9+/bYunWr/EK0RMpaMQD5BZ89ntWUKVMwadIks/iJEyfg7OwMAChdujQCAwMRGxuLO3fuyG18fHzg4+ODf//9F3q9Hn0rZbwhd9+UcF4voVMFE9zsHy9z4zUNriUBvQNNiiK8MlaDhwbI82daEKNBcR3wmv/jeLoJWBCjRTlnoLXP43jCI2BFrBaVXAUaez1+o19LBjZe1aKmu0At98fx83oJu29KeKmMQBXXx/Hj8RKO3ZXQwscEn2KPc8lLn44ePYqwsDDY29vj6NGjij7VqVMHjx49wqlTp+SYVqtF3bp1odfrFUc3nJycUL16ddy9exeXLl2S466urggODkZcXByuXbsmx3P7PGUKCAiAp6cnTp8+jZSUFDkeFBQENzc3nDhxQvHBxj6xT0WpT0lJSaD898IM/jp27Bjq1KmDFStW4LXXXjObvmHDBrRt2xbnzp1DUFCQHD9y5AjCw8OxZ88evPzyyxaXbWmP2dfXF/Hx8fKAiNx+Gw4evwkA95jPTX6Fey3sE/uk8j4lJibC3d2dg7/yWZHeY86qRo0a0Gq1uHDhgsXpmeeWsxfms2fPQpIkRSw7BwcHODg4mMV1Oh10OuUmznwzZZf55kg3KffMDULKqLjZZG/3OG4eE1bjksW4CRJMluJCgslCLkYhwWghbi333PQp63bLvg2fFJckyWLc2nbPa9za0RZr8bzkbi3OPrFPgDr7ZG1Z9GyK9OCvrA4cOACj0YiAgACL0/39/REUFIRly5Yp4kuXLkV4eDg8PDwKI00iInrBFcmvO507d0adOnUQFhYGJycnnDx5El9//TXCwsLQsWNHAED//v2xcOFCxeGeyZMno3v37ggMDESLFi2wZs0abNmyBZs2bbJRT4iI6EVTJAtzeHg4li1bhqlTp8JkMqFChQoYMGAARowYAXv7jFFHRqPR7JxK165dkZycjC+//BJRUVGoWLEili1bhpYtW9qiG0RE9AJ6YQZ/FSZe+evZ8cpfROrHK38VjBfmHDMREdHzQDWHshMTE3Hw4EFcv34dKSkp8PDwQEhICEJDQ22dGhERUaGxaWE2GAxYuXIlZs+ejX379sFkMil+jydJEtzd3dG7d28MGjQIlSpVsmG2REREBc9mh7LXrl2LkJAQvPnmm3B2dsaXX36JLVu24OTJkzh//jwOHDiAX3/9FT169MDq1asREhKCd999lzeTICKiIs1mg79KliyJYcOG4d1334Wnp2eO7bdt24YvvvgCkZGRGD9+fCFk+PQ4+OvZcfAXkfpx8FfBsNmh7NjYWLMbRjxJs2bN0KxZMyQkJBRYTkRERLZms0PZeSnK+TEfERHR
80AVP5dKTU2Vb5WYafny5Rg1ahS2bdtmo6yIiIgKnyoK8xtvvIEhQ4bIj2fMmIEePXrg66+/RsuWLbFhwwYbZkdERFR4VFGYDx8+jFdeeUV+PGPGDLz++utISEhA586dERUVZcPsiIiICo8qCvOdO3dQrlw5ABmDwi5duoQPPvgALi4u6N+/P06fPm3jDImIiAqHKgpzsWLFoNfrAQB79uxB8eLFUadOHQCAo6MjHj58aMv0iIiICo0qLslZrVo1/Pjjj/Dz88PMmTPRpEkTSJIEALhy5Qq8vLxsnCEREVHhUEVhHjduHNq1a4caNWrA3t4ef/31lzxt/fr1qFWrlg2zIyIiKjyqKMxNmzbFuXPncOzYMdSoUQMBAQGKaTVq1LBdckRERIVIFYUZAPz8/ODn52cWHzhwoA2yISIisg2bFeYrV67kqX358uULKBMiIiL1sFlhrlChgjzAKzeMRmMBZkNERKQONivM8+bNkwtzeno6Pv/8cxQrVgzdu3eHl5cXbty4gWXLliE5OVn1d5MiIiLKLzYrzH379pX/HzNmDEJCQrBu3TpoNI9/Wj1+/Hi0bdsWMTExNsiQiIio8KniAiOLFi3CoEGDFEUZADQaDQYNGoRffvnFRpkREREVLlUU5vj4eKSkpFiclpKSgvv37xdyRkRERLahisJcq1YtTJ48GXfv3lXE79y5g8mTJ6NmzZo2yoyIiKhwqeJ3zN988w2aN2+OChUqoFmzZvDy8sLNmzflezFnvRIYERFRUaaKPeb69evjyJEjaNeuHQ4dOoR58+bh0KFDaN++PQ4dOoT69evbOkUiIqJCoYo9ZgAIDg7Gb7/9Zus0iIiIbEoVe8xERESUQTV7zHv37sWSJUvw33//mY3QliRJPt9MRERUlKmiMM+fPx/9+/dHqVKlULlyZTg4OCimCyFslBkREVHhUkVh/vrrr9GtWzcsXLjQrCgTERG9SFRxjvm///7D22+/zaJMREQvPFUU5uDgYNy6dcvWaRAREdmcKgrzl19+ialTp+L69eu2ToWIiMimVHGO+ccff4Rer0flypVRo0YNuLu7K6ZLkoQ1a9bYKDsiIqLCo4rCfOrUKWi1Wnh6eiIuLg5xcXGK6Zn3bSYiIirqVFGYL1++bOsUiIiIVEEV55iJiIgogyr2mAEgPT0dixYtwrZt2xAfHw8PDw80b94cr7/+Ouzs7GydHhERUaFQRWHW6/Vo1qwZjh8/DmdnZ3h5eWH//v1YunQpZs6ciW3btsHFxcXWaRIRERU4VRzKHjt2LM6fP49ly5bhwYMHiImJwYMHD7B8+XKcP38eY8eOtXWKREREhUIVhXn16tWYPHkyunbtqoi/9tprmDhxIv744w8bZUZERFS4VFGY79y5g7CwMIvTqlevjrt37xZyRkRERLahisJcrlw57N271+K0ffv2wdvbu5AzIiIisg1VFObu3bvjyy+/xP/93/8hPj4eABAfH4/vvvsOX375JXr06JGn5a1YsQIdO3aEr68vnJ2dERYWhlmzZsFkMj1xvsjISEiSZPYXHR391H0jIiLKC0mo4GbHaWlp6NChA7Zs2QJJkqDT6WAwGCCEQKtWrbBmzRrY29vnenn169eHn58fOnXqhDJlymDHjh2YMmUKhg4dimnTplmdLzIyEgaDAVFRUYp4jRo14OjomOv1JyYmwtXVFXq9Ps+jySuMWp+n9kXV5altbZ0CEeXgWT7ryDpV/FzKwcEBmzZtwubNm7Fjxw7Ex8fD3d0dzZo1Q4sWLfK8vD///BOlS5eWHzdp0gQPHz7EDz/8gM8///yJt5d0c3ND/fr1n6ofREREz0oVhTlTq1at0KpVq2deTtainKlmzZpITU3FvXv3ULZs2WdeBxERUUFQxTnmgwcPYvny5RanLV++HIcOHXrmdezZswelSpWCp6fnE9vt2rULzs7OcHR0REREBHbv3v3M6yYiIsotVewxjxkzBi+99BK6detmNu3s2bP4
3//+h61btz718o8ePYr58+djwoQJ0Gq1VttFRETgzTffRKVKlRAXF4eoqCg0b94cu3btQoMGDazOl5aWhrS0NPlxYmIiAMBgMMBgMAAANBoNNBoNTCaTYhBaZtxoNEIIATtNxil/owkwQYJOEsh6cy2DCRCQ5HbKOGCX7atWugmQAOjM4hIkCEVcCMAgJGggoLUUlwS0WXIxCcAoJGglAU2WuFEAJmGee176ZDAY5OfKaDQqcrcW1+l0EEIo4pIkQavVmm13a/HcPk85xbVaLSRJkp//nHJnn9in57FP2ddB+UMVg788PDywcOFCtG1rPuBn48aN6NOnD27fvv1Uy7558ybq1asHHx8f7Ny5M0/X3U5KSkLVqlUREhKCDRs2WG03ceJETJo0ySz+119/wdnZGUDG4fXAwEBcvHgRd+7ckdv4+PjAx8cH586dg16vx7ZzGf3cfVPCeb0GXf2NcMsy7m3jNQ2uJUnoW8moKMIrYzV4aAD6VlKOPF8Qo0FxHfCa/+N4uglYEKOFj7NAa5/H8YRHwIpYLaq4mtDY6/HL4loysPGqFrU9TKjl/jh+Xi9h900NGnuZUMX1cfx4vIRjdzVo7WuET7HHueSlT82CPREWFgZ7e3scPXpU0ac6derg0aNHOHXqlBzTarWoW7cuEhISFKPonZycUL16ddy+fRuXLl2S466urggODsa1a9dw7do1OZ7b5ylTQEAAPD09cfLkSaSkpMjxoKAguLm54ciRI4oPNvaJfSpKfUpKSkLz5s05+CufqaIwOzk5YfXq1RbPL2/evBkdO3ZUvPByS6/XIzIyEqmpqdi7dy/c3d3zvIzBgwdj5cqVuHXrltU2lvaYfX19ER8fL79Yc/ttOHj8JgDcYz43+RXutbBP7JPK+5SYmAh3d3cW5nymikPZ/v7+2LFjh8XCvGPHDvj5+eV5mampqXj11Vdx69YtHDhw4KmKMgDk5nuLg4ODxZHeOp0OOp1yE2e+mbLLfHOkmyRF3CCkjIqbTfZ2j+PmMWE1LlmMmyDB0k++TUKCyUIuRiHBaCFuLffc9Cnrdsu+DZ8Uz/y5XXbWtnte49ZOhViL5yV3a3H2iX0C1Nkna8uiZ6OKwV89evTAt99+i/nz5yviCxYswPTp09GzZ888Lc9gMKBbt244efIkNm3a9FSFHcg4lL1+/XrUrVv3qeYnIiLKK1V83Rk1ahR27tyJ/v374/3334e3tzfi4uKQmpqKyMhIjB49Ok/LGzx4MP788098/fXXSE5OxsGDB+VpISEhcHFxQf/+/bFw4UL5cM+ePXsQFRWFTp06wc/PD3Fxcfjmm29w8+ZNrFixIl/7S0REZI0qCrO9vT22bt2KJUuWYOPGjbh79y7Cw8PRunVr9OzZ84kjqS3ZvHkzAOCTTz4xm7Zjxw5ERkbCaDQqzp2ULVsWaWlpGD16NOLj4+Hs7IyGDRti9uzZCA8Pf7YOEhER5ZIqBn8VNbwk57PjJTmJ1I+X5CwYqthjzhQdHY1du3bh7t276N+/P7y8vBAXF4eSJUvCycnJ1ukREREVOFUUZqPRiAEDBmDBggUQQkCSJLRu3RpeXl4YOHAgatasicmTJ9s6TSIiogKnilHZX3zxBZYsWYJp06bh9OnTip8otW7dGps2bbJhdkRERIVHFXvMCxYswLhx4zB8+HCzH7n7+/sjNjbWRpkREREVLlXsMV+/ft3qtagdHR3x4MGDQs6IiIjINlRRmD09PRXXks3q/Pnz8PHxKeSMiIiIbEMVhblNmzb44osvcP36dTkmSRL0ej1mzJiB9u3b2zA7IiKiwqOKwjx58mQYDAaEhISgS5cukCQJY8aMQWhoKFJTUzFu3Dhbp0hERFQoVFGYy5QpgyNHjqBnz544duwYtFotTp48idatW2P//v0oVaqUrVMkIiIqFKoYlQ1kFOfZs2fbOg0iIiKbUsUesyVXr17Fpk2bEB8fb+tUiIiICo0qCvOnn36KYcOGyY//
+usvVK5cGW3atEHlypVx5swZG2ZHRERUeFRRmFetWoWQkBD58aeffoqwsDCsXr0afn5++Pzzz22YHRERUeFRxTnm69evo2LFigCA+Ph4HDlyBBs2bECrVq2QmpqKjz76yMYZEhERFQ5V7DELIWAymQAA+/btg1arRePGjQFk3Cf57t27tkyPiIio0KiiMAcGBmLdunUAgN9++w3h4eHybR5v3LiBkiVL2jI9IiKiQqOKQ9kDBw7E4MGDsWjRIiQkJGDevHnytH379inOPxMRERVlqijM7733HkqWLIn9+/cjPDwcr7/+ujwtJSUFffv2tV1yREREhUgSWW9+TPkiMTERrq6u0Ov1cHFxydO8FUatL6Csni+Xp7a1dQpElINn+awj61RxjpmIiIgy2Kwwh4aG4o8//sh1+xs3bmDIkCGYOnVqAWZFRERkWzYrzN26dcObb76J8uXLY/To0di8eTPu3LmDzCPrKSkpOH36NH7++We0b98efn5+OHbsGF599VVbpUxERFTgbHqO+caNG5g+fTrmzZuH+Ph4SJIESZJgZ2eHR48eAcj4jXOjRo3w4YcfonPnzrZKNU94jvnZ8RwzkfrxHHPBsOmo7LJly+Krr77C559/jkOHDuHAgQOIi4tDSkoKPDw8EBQUhMjISPj4+NgyTSIiokKjip9L2dnZ4eWXX8bLL79s61SIiIhsiqOyiYiIVISFmYiISEVYmImIiFSEhZmIiEhFWJiJiIhUhIWZiIhIRVRTmO/cuYPRo0ejQYMGqFy5Ms6cOQMA+Omnn3DixAkbZ0dERFQ4VFGYY2NjUb16dcyYMQOSJOHixYtIS0sDAJw6dQozZsywcYZERESFQxWF+ZNPPoGbmxtiYmKwe/duZL1K6Msvv4x9+/bZMDsiIqLCo4orf23btg2zZs2Ct7c3jEajYlrZsmURFxdno8yIiIgKlyr2mFNTU1GqVCmL05KSkqDRqCJNIiKiAqeKilelShX89ddfFqft3r0boaGhhZwRERGRbajiUPY777yD4cOHw9vbG7179wYAPHr0CCtXrsTMmTPxww8/2DhDIiKiwqGKwjxo0CD8/fffGDZsGD766CMAGYO+hBB455130KdPHxtnSEREVDhUUZgBYM6cOejXrx/Wr1+PW7duwcPDA+3atUPDhg1tnRoREVGhUU1hBoD69eujfv36tk6DiIjIZlRVmAEgOTkZqampZnFro7aJiIiKElUU5uTkZIwZMwaLFy/GvXv3LLbJ/vtmIiKiokgVhfn999/HL7/8gvbt2yM4OBj29vbPtLwVK1Zg8eLFOHbsGO7du4fAwEC89957GDhwYI6/iV64cCGmTJmCy5cvo2LFipgwYQK6du36TPkQERHllioK859//okpU6ZgxIgR+bK8b775Bn5+fpg2bRrKlCmDHTt2YMiQIbh06RKmTZtmdb6VK1eib9++GDVqFFq2bInVq1eje/fucHV1RcuWLfMlNyIioieRRNYLU9tI6dKl8dtvv6FZs2b5srw7d+6gdOnSitjw4cMxa9YsJCQkwMHBweJ8wcHBqFatGpYvXy7HWrVqBb1ej4MHD+Z6/YmJiXB1dYVer4eLi0uecq8wan2e2hdVl6e2tXUKRJSDZ/msI+tUceWvzp07Y8uWLfm2vOxFGQBq1qyJ1NRUq+ewY2NjER0djZ49eyrivXr1wuHDh3H37t18y4+IiMgaVRzK/uabb9ClSxcMHz4cbdq0sTgCu1atWs+0jj179qBUqVLw9PS0OP3cuXMAMvaaswoJCYEQAtHR0Xj55ZefKQciIqKcqKIwp6SkwGAwYPr06fjuu+8U04QQkCTpmUZlHz16FPPnz8eECROg1Wottrl//z4AwM3NTREvWbIkAFjd0waAtLQ0+f7RQMbhHQAwGAwwGAwAAI1GA41GA5PJBJPJJLfNjBuNRgghYKfJOLNgNAEmSNBJApL0eF0GEyAgye2UccAu2zGQdBMgAdCZxSVIEIq4
EIBBSNBAQGspLglos+RiEoBRSNBKAposcaMATMI897z0yWAwyM9V9ufeWlyn00EIoYhLkgStVmu23a3Fc/s85RTXarWQJEl+/nPKnX1in57HPmVfB+UPVRTm/v3748iRIxg6dGi+jMrO6ubNm+jSpQvCw8MxcuTIHNtLWSsGIL/gs8ezmjJlCiZNmmQWP3HiBJydnQFkHF4PDAxEbGws7ty5I7fx8fGBj48P/v33X+j1evStlPGG3H1Twnm9hE4VTHDLsjk2XtPgWhLQO9CkKMIrYzV4aIA8f6YFMRoU1wGv+T+Op5uABTFalHMGWvs8jic8AlbEalHJVaCx1+M3+rVkYONVLWq6C9Ryfxw/r5ew+6aEl8oIVHF9HD8eL+HYXQktfEzwKfY4l7z06ejRowgLC4O9vT2OHj2q6FOdOnXw6NEjnDp1So5ptVrUrVsXer0e0dHRctzJyQnVq1fH3bt3cenSJTnu6uqK4OBgxMXF4dq1a3I8t89TpoCAAHh6euL06dNISUmR40FBQXBzc8OJEycUH2zsE/tUlPqUlJQEyn+qGPxVokQJ/N///R/eeeedfF2uXq9HZGQkUlNTsXfvXri7u1ttu2HDBrRt2xbnzp1DUFCQHD9y5AjCw8OxZ88eq4eyLe0x+/r6Ij4+Xh4Qkdtvw8HjNwHgHvO5ya9wr4V9Yp9U3qfExES4u7tz8Fc+U8Uec4kSJVChQoV8XWZqaipeffVV3Lp1CwcOHHhiUQYen1vOXpjPnj0LSZIUsewcHBwsjvTW6XTQ6ZSbOPPNlF3mmyPdpNwzNwgpo+Jmk73d47h5TFiNSxbjJkgwWYoLCSYLuRiFBKOFuLXcc9OnrNst+zZ8UlySJItxa9s9r3Frp0KsxfOSu7U4+8Q+Aersk7Vl0bNRxajsN998E7/99lu+Lc9gMKBbt244efIkNm3aBD8/vxzn8ff3R1BQEJYtW6aIL126FOHh4fDw8Mi3/IiIiKxRxded6tWrY+zYsejUqRPatm1rcVR2586dc728wYMH488//8TXX3+N5ORkxW+QQ0JC4OLigv79+2PhwoWKwz2TJ09G9+7dERgYiBYtWmDNmjXYsmULNm3a9GwdJCIiyiVVFObevXsDAC5fvow1a9aYTc/rqOzNmzcDAD755BOzaTt27EBkZCSMRqPZMrt27Yrk5GR8+eWXiIqKQsWKFbFs2TJe9YuIiAqNKgZ/7dq1K8c2ERERhZBJ/uCVv54dr/xFpH688lfBUMUe8/NUdImIiAqSKgZ/ERERUQab7TH369cP48aNg7+/P/r16/fEtpIkYe7cuYWUGRERke3YrDDv2LEDH374IQBg+/btT7yy1pOmERERFSU2K8yxsbHy/5cvX7ZVGkRERKpis3PMWq0Whw8fttXqiYiIVMlmhVkFv9IiIiJSHY7KJiIiUhGbFmYO6iIiIlKy6QVGPvroI7i5ueXYTpIki5fqJCIiKmpsWpgvXLhg8XaJ2XHPmoiIXhQ2LcyrV69GeHi4LVMgIiJSFQ7+IiIiUhEWZiIiIhVhYSYiIlIRmxXmHTt2ICQkxFarJyIiUiWbDf7iPZiJiIjM8VA2ERGRirAwExERqQgLMxERkYqwMBMREamITa/8ZcmdO3eQkpJiFi9fvrwNsiEiIipcqijMDx48wLBhw7B06VKkpqZabGM0Ggs5KyIiosKnisI8dOhQLFmyBP3790dYWFiubmxBRERUFKmiMK9fvx5Tp07Fhx9+aOtUiIiIbEoVg79SU1NRrVo1W6dBRERkc6oozG3atMGePXtsnQYREZHN2exQ9r179+T/P/30U7z22msoUaIE2rdvD3d3d7P2pUqVKsz0iIiIbMJmhdnDwwOSJMmPhRD4+OOP8fHHH1tsz1HZRET0IrBZYR4/fryiMBMREZENC/PEiRNttWoiIiLVUsXgLyIiIsqgisI8fPhw9O7d2+K0119/3ep5ZyIioqJGFYV57dq1aNmypcVpLVu2xJo1
awo5IyIiIttQRWG+fv06KlSoYHGan58frl27VrgJERER2YgqCrOzszOuXr1qcdqVK1fg6OhYyBkRERHZhioKc4MGDfDNN98gPT1dEU9PT8e3336Lhg0b2igzIiKiwqWKm1h8+umnaNy4MUJDQ9G/f3+UK1cO165dw7x58/Dff/9h9uzZtk6RiIioUKiiMNerVw9r167F4MGDMWrUKDkeGBiItWvXIjw83IbZERERFR5VFGYAaNWqFS5cuICYmBjcuXMHpUuXRqVKlWydFhERUaFSTWHOVKlSJRZkIiJ6Yali8BcAXLx4EW+88Qa8vb3h4OCAcuXKoU+fPrh48aKtUyMiIio0qthjjo6ORoMGDZCamoqmTZvC29sbcXFxWL58OdatW4d9+/YhKCjI1mkSEREVOFXsMY8ZMwbu7u6IiYnB+vXr8b///Q/r169HTEwM3N3dMXbs2Dwt78KFC3j33XdRo0YN6HQ6hIaG5mq+yMhISJJk9hcdHf003SIiIsozVewx79q1CzNmzICPj48i7uPjg/Hjx2PIkCF5Wt6ZM2ewfv161KtXDyaTCSaTKdfzvvTSS4iKilLErF2VjIiIKL+pojAnJyfD3d3d4jQPDw+kpKTkaXnt27dHhw4dAAB9+/bF0aNHcz2vm5sb6tevn6f1ERER5RdVHMquUqUKFi9ebHHa0qVL83x+WaNRRbeIiIjyTBV7zEOGDMHbb78NvV6PPn36oGzZsrhx4wZ+/fVXrF27Fj///HOh5bJr1y44OzvDaDSiXr16+Oyzz9C4ceNCWz8REb3YVFGY+/Xrh1u3buHzzz/H+vXrAQBCCDg5OeGLL77AW2+9VSh5RERE4M0330SlSpUQFxeHqKgoNG/eHLt27UKDBg2szpeWloa0tDT5cWJiIgDAYDDAYDAAyNiL12g0Zue8M+NGoxFCCNhpBADAaAJMkKCTBCTp8boMJkBAktsp44BdtoMF6SZAAqAzi0uQIBRxIQCDkKCBgNZSXBLQZsnFJACjkKCVBDRZ4kYBmIR57nnpk8FggFarzZjPaFTkbi2u0+kghFDEJUmCVqs12+7W4rl9nnKKa7VaSJIkP/855c4+sU/PY5+yr4PyhySyPqM2ptfrsX//fty7dw/u7u5o0KABXF1dn2mZmeeYT58+ned5k5KSULVqVYSEhGDDhg1W202cOBGTJk0yi//1119wdnYGAJQuXRqBgYG4ePEi7ty5I7fx8fGBj48Pzp07B71ej23nbgMAdt+UcF6vQVd/I9zsHy9z4zUNriVJ6FvJqCjCK2M1eGgA+lZSDnRbEKNBcR3wmv/jeLoJWBCjhY+zQGufx/GER8CKWC2quJrQ2Ovxy+JaMrDxqha1PUyo5f44fl4vYfdNDRp7mVDF9XH8eLyEY3c1aO1rhE+xx7nkpU/Ngj0RFhYGe3t7szECderUwaNHj3Dq1Ck5ptVqUbduXSQkJChG0Ts5OaF69eq4ffs2Ll26JMddXV0RHByMa9euKW4rmtvnKVNAQAA8PT1x8uRJxViIoKAguLm54ciRI4oPNvaJfSpKfUpKSkLz5s2h1+vh4uICyh+qKswF4VkKMwAMHjwYK1euxK1bt6y2sbTH7Ovri/j4ePnFmttvw8HjNwHgHvO5ya9wr4V9Yp9U3qfExES4u7uzMOczVRzKBjKK2Y8//ogdO3YgPj4e7u7uaNKkCd577z24ubnZLK/cfG9xcHCAg4ODWVyn00GnU27izDdTdplvjnSTpIgbhJRRcbPJ3u5x3DwmrMYli3ETJFj6hZlJSDBZyMUoJBgtxK3lnps+Zd1u2bfhk+KSJFmMW9vueY1nPk+5jecld2tx9ol9AtTZJ2vLomejiuHLsbGxCAsLw9ixYxETEwN7e3vExMRg7NixqF69uuIwUGFKSkrC+vXrUbduXZusn4iIXjyqKMwffvghUlNTsW/fPsTGxuLAgQOIjY3F3r17kZaWhqFDh+ZpecnJyVi5ciVWrlyJ//77
D4mJifLjzHM3/fv3V3zb27NnDzp06IAFCxZgx44dWLx4MRo1aoSbN29i/Pjx+dldIiIiq1RxHGL79u347rvvzEY+N2zYEJ9//nmeC/Pt27fRtWtXRSzz8Y4dOxAZGQmj0ag4d1K2bFmkpaVh9OjRiI+Ph7OzMxo2bIjZs2fzftBERFRoVFGYHRwc4Ovra3Fa+fLlLZ6/fZIKFSrkeG54wYIFWLBggfy4YsWK2LRpU57WQ0RElN9UcSi7Q4cOWLFihcVpK1asQLt27Qo5IyIiIttQxR5zr1690L9/f3Tt2hW9evWCl5cXbt68icWLF+Po0aOYO3cujh8/LrevVauWDbMlIiIqOKoozC1btgQAXL16Fb///rsczzwcnTldCAFJksx+b0dERFRUqKIwz58/39YpEBERqYIqCnOfPn1snQIREZEqqGLwV05Mli5DRUREVATZrDAHBATg5MmT8mMhBAYMGICrV68q2h06dAh2dnaFnR4REZFN2KwwX758WXHjB5PJhLlz5yruqkJERPSiUdWh7CJ+oysiIqIcqaowExERvehYmImIiFTEpoVZkszvKWwpRkRE9KKw6e+Ye/XqBScnJ0Wse/fucHR0lB+npKQUdlpEREQ2Y7PC3LhxY7O944iICIttfXx8CiMlIiIim7NZYd65c6etVk1ERKRaHPxFRESkIizMREREKsLCTEREpCIszERERCrCwkxERKQiLMxEREQqYtMLjGQXHR2NXbt24e7du+jfvz+8vLwQFxeHkiVLml2IhIiIqChSRWE2Go0YMGAAFixYACEEJElC69at4eXlhYEDB6JmzZqYPHmyrdMkIiIqcKo4lP3FF19gyZIlmDZtGk6fPq24/WPr1q2xadMmG2ZHRERUeFSxx7xgwQKMGzcOw4cPh9FoVEzz9/dHbGysjTIjIiIqXKrYY75+/ToaNGhgcZqjoyMePHhQyBkRERHZhioKs6enJy5dumRx2vnz53kTCyIiemGoojC3adMGX3zxBa5fvy7HJEmCXq/HjBkz0L59extmR0REVHhUUZgnT54Mg8GAkJAQdOnSBZIkYcyYMQgNDUVqairGjRtn6xSJiIgKhSoKc5kyZXDkyBH07NkTx44dg1arxcmTJ9G6dWvs378fpUqVsnWKREREhUIVo7KBjOI8e/ZsW6dBRERkU6rYYyYiIqIMqthj7tevn9VpGo0Gbm5uqFu3Ljp16gR7e/tCzIyIiKhwqaIw79ixA3q9HgkJCdDpdHB3d0d8fDwMBgPc3NwghMD//d//oUqVKti5cyfKlClj65SJiIgKhCoOZa9atQolSpTA0qVLkZKSghs3biAlJQVLlixBiRIlsHnzZuzduxf379/HmDFjbJ0uERFRgVHFHvPw4cMxYsQIdO/eXY5ptVr06NEDt27dwvDhw7F3716MHDkSUVFRNsyUiIioYKlij/nIkSMICQmxOC00NBQnTpwAANSoUQN3794tzNSIiIgKlSoKs4uLC3bs2GFx2vbt2+Hi4gIASElJQYkSJQozNSIiokKlikPZvXr1wldffQUhBLp27YoyZcrg1q1bWLZsGb755ht8+OGHAIBjx44hODjYxtkSEREVHFUU5ilTpuDGjRuYMmUKpk6dKseFEOjZsye+/PJLAECDBg3QqlUrW6VJRERU4FRRmO3t7bFkyRKMGzcOu3btQnx8PNzd3dG4cWPFuefmzZvbMEsiIqKCp4rCnCk4OJiHqomI6IWmqsIMAHfu3EFKSopZvHz58jbIhoiIqHCpYlQ2AHz++efw9PSEl5cX/P39zf7y4sKFC3j33XdRo0YN6HQ6hIaG5nrehQsXIigoCI6OjggNDcWKFSvy2hUiIqKnporCPG/ePEydOhVDhgyBEAJjxozB6NGj4ePjg0qVKuHnn3/O0/LOnDmD9evXo2LFilZ/H23JypUr0bdvX3Tq1AkbN25Es2bN0L17d2zZsiWvXSIiInoqkhBC2DqJ2rVro0uXLhg5
ciTs7Oxw9OhR1KpVCykpKWjcuDG6deuGjz/+ONfLM5lM0GgyvnP07dsXR48exenTp3OcLzg4GNWqVcPy5cvlWKtWraDX63Hw4MFcrz8xMRGurq7Q6/Xyb7Bzq8Ko9XlqX1RdntrW1ikQUQ6e5bOOrFPFHvOFCxdQv359uZg+evQIAODk5ISPPvoIc+bMydPyMpeTF7GxsYiOjkbPnj0V8V69euHw4cO84hgRERUKVRRmnS5jDJokSXBxccG1a9fkaR4eHrh+/XqB53Du3DkAMBsVHhISAiEEoqOjCzwHIiIiVYzKrlSpEq5evQoAqFu3Lv73v/+hQ4cO0Gg0mDNnDipUqFDgOdy/fx8A4ObmpoiXLFkSAHDv3j2r86alpSEtLU1+nJiYCAAwGAwwGAwAMvbiNRoNTCYTTCaT3DYzbjQaIYSAnSbjzILRBJggQScJSNLjdRlMgIAkt1PGAbtsX7XSTYAEQGcWlyBBKOJCAAYhQQMBraW4JKDNkotJAEYhQSsJaLLEjQIwCfPc89Ing8EArVabMZ/RqMjdWlyn00EIoYhLkgStVmu23a3Fc/s85RTXarWQJEl+/nPKnX1in57HPmVfB+UPVRTm1q1bY/fu3ejTpw9Gjx6NVq1awc3NDTqdDg8fPsS8efMKLRcpa8UA5Bd89nhWU6ZMwaRJk8ziJ06cgLOzMwCgdOnSCAwMRGxsLO7cuSO38fHxgY+PD/7991/o9Xr0rZTxhtx9U8J5vYROFUxws3+8zI3XNLiWBPQONCmK8MpYDR4aIM+faUGMBsV1wGv+j+PpJmBBjBblnIHWPo/jCY+AFbFaVHIVaOz1+I1+LRnYeFWLmu4Ctdwfx8/rJey+KeGlMgJVXB/Hj8dLOHZXQgsfE3yKPc4lL306evQowsLCYG9vj6NHjyr6VKdOHTx69AinTp2SY1qtFnXr1oVer1cc3XByckL16tVx9+5dXLp0SY67uroiODgYcXFxiiM0uX2eMgUEBMDT0xOnT59W/MwvKCgIbm5uOHHihOKDjX1in4pSn5KSkkD5TxWDv7I7cuQIfvvtN0iShLZt26JJkyZPvazcDv7asGED2rZti3PnziEoKEiRS3h4OPbs2YOXX37Z4ryW9ph9fX0RHx8vD4jI7bfh4PGbAHCP+dzkV7jXwj6xTyrvU2JiItzd3Tn4K5/ZfI85NTUVixYtQqNGjeTzu3Xr1kXdunULNY/MdWcvzGfPnoUkSYpYdg4ODnBwcDCL63Q6+fx5psw3U3aZb450k3LP3CCkjIqbTfZ2j+PmMWE1LlmMmyDBZCkuJJgs5GIUEowW4tZyz02fsm637NvwSXFJkizGrW33vMYzn6fcxvOSu7U4+8Q+Aersk7Vl0bOx+eAvR0dHDBkyBLdv37ZpHv7+/ggKCsKyZcsU8aVLlyI8PBweHh42yoyIiF4kqvi6ExAQgJs3b+bb8pKTk7FhwwYAwH///YfExESsXLkSABAREYHSpUujf//+WLhwoeJwz+TJk9G9e3cEBgaiRYsWWLNmDbZs2YJNmzblW25ERERPoorC/OGHH2Lq1Klo3bp1vpynuH37Nrp27aqIZT7esWMHIiMjYTQazc6pdO3aFcnJyfjyyy8RFRWFihUrYtmyZWjZsuUz50RERJQbqhj8NWTIEPzxxx9ISkpC06ZNUbZsWcUoaEmS8N1339kww7zhlb+eHa/8RaR+vPJXwVBFYc7pSl2SJJnt3aoZC/OzY2EmUj8W5oKhikPZJktDgImIiF5ANh+VTURERI+pqjBv3rwZo0ePxjvvvIMrV64AyLjAR9ar4BARERVlqjiUnZycjA4dOmDbtm3yoK/33nsP5cuXR1RUFHx9fREVFWXjLImIiAqeKvaYx44di6NHj2LVqlXQ6/WKy8u1bNkSf/31lw2zIyIiKjyq2GNesWIFPvvsM3Tq1Mls9HX58uXlw9pERERFnSr2mO/cuYOqVatanKbRaBR3TiEiIirKVFGYy5Ur
h3/++cfitFOnTsHf37+QMyIiIrINVRTmzp0744svvsCJEyfkmCRJ+O+///Dtt9+aXV6TiIioqFJFYZ4wYQK8vb0RHh6OOnXqQJIkvPXWWwgNDYWnpydGjRpl6xSJiIgKhSoKc4kSJbB//3589tlnKF68OAIDA1GsWDGMHj0au3fvhpOTk61TJCIiKhSqGJUNAE5OThg1ahT3jomI6IWmij3mESNG4OzZs7ZOg4iIyOZUUZh//PFHVKtWDeHh4fjpp5+g1+ttnRIREZFNqKIw37x5Ez/88AM0Gg3ee+89lC1bFr1798a2bdtsnRoREVGhUkVhdnV1xXvvvYeDBw/izJkzeP/997Fjxw60aNECfn5+mDBhgq1TJCIiKhSqKMxZBQcH4+uvv8a1a9ewevVqCCHw+eef2zotIiKiQqGaUdlZ/fvvv1iwYAEWLVqEuLg4+Pr62jolIiKiQqGaPeaHDx9i7ty5ePnllxEcHIxvv/0WjRo1wubNm3H58mVbp0dERFQoVLHH3KdPH6xatQrJycmoXbs2fvjhB/Ts2RNubm62To2IiKhQqaIwb9q0CQMHDpQvw5ndnTt3ULp0aRtkRkREVLhUUZivX78OnU6ZihACGzduxNy5c7Fu3TqkpaXZKDsiIqLCo4rCnLUoX7x4EfPmzcPChQtx48YN2Nvbo0uXLjbMjoiIqPCoojCnpqZixYoVmDt3Lvbs2QMhBCRJwvDhwzFq1Ci4u7vbOkUiIqJCYdNR2UeOHMG7774LLy8v9O3bF8ePH0ffvn2xbt06CCHQvn17FmUiInqh2GyPOSwsDGfOnAEANGjQAP369UP37t3h7OzMa2UTEdELy2aF+fTp05AkCW3btsXUqVMREhJiq1SIiIhUw2aHsqdPn46wsDCsW7cO1apVQ4MGDfDzzz/jwYMHtkqJiIjI5mxWmIcMGYITJ07g8OHDGDBgAKKjozFgwACULVsWAwYMgCRJkCTJVukRERHZhM0vyVmnTh3MmjULN27cwMKFC1GnTh2sXLkSQgj0798f33zzDeLj422dJhERUaGweWHO5OjoiDfeeAM7d+7Ev//+i1GjRiE5ORkff/wxb2JBREQvDNUU5qwCAwPx5Zdf4sqVK1i7di1eeeUVW6dERERUKFRxgRFrNBoN2rVrh3bt2tk6FSIiokKhyj1mIiKiFxULMxERkYqwMBMREakICzMREZGKsDATERGpCAszERGRirAwExERqQgLMxERkYqwMBMREalIkS3M//77L1555RU4OzvD09MTH374IVJSUnKcLzIyUr6zVda/6OjoQsiaiIhedKq+JOfTSkhIQNOmTeHn54dVq1bh9u3bGD58OOLj4/Hrr7/mOP9LL72EqKgoRaxChQoFlC0REdFjRbIw//TTT7h//z7+/vtveHh4AAB0Oh169+6NsWPHIjg4+Inzu7m5oX79+oWRKhERkUKRPJS9YcMGNG/eXC7KANClSxc4ODhgw4YNNsyMiIjoyYpkYT537pzZXrGDgwMCAwNx7ty5HOfftWsXnJ2d4ejoiIiICOzevbugUiUiIlIokoey79+/Dzc3N7N4yZIlce/evSfOGxERgTfffBOVKlVCXFwcoqKi0Lx5c+zatQsNGjSwOE9aWhrS0tLkx4mJiQAAg8EAg8EAIOMWlhqNBiaTCSaTSW6bGTcajRBCwE4jAABGE2CCBJ0kIEmP12UwAQKS3E4ZB+yyfdVKNwESAJ1ZXIIEoYgLARiEBA0EtJbikoA2Sy4mARiFBK0koMkSNwrAJMxzz0ufDAYDtFptxnxGoyJ3a3GdTgchhCIuSRK0Wq3ZdrcWz+3zlFNcq9VCkiT5+c8pd/aJfXoe+5R9HZQ/imRhBjJeQNkJISzGs5o0aZLicbt27VC1alV89tlnVg+DT5kyxWw+ADhx4gScnZ0BAKVLl0ZgYCBiY2Nx584duY2Pjw98fHzw77//Qq/Xo2+ljDfk7psSzusldKpggpv942VuvKbB
tSSgd6BJUYRXxmrw0AB5/kwLYjQorgNe838cTzcBC2K0KOcMtPZ5HE94BKyI1aKSq0Bjr8dv9GvJwMarWtR0F6jl/jh+Xi9h900JL5URqOL6OH48XsKxuxJa+JjgU+xxLnnp09GjRxEWFgZ7e3scPXpU0ac6derg0aNHOHXqlBzTarWoW7cu9Hq9YgS9k5MTqlevjrt37+LSpUty3NXVFcHBwYiLi8O1a9fkeG6fp0wBAQHw9PTE6dOnFaP+g4KC4ObmhhMnTig+2Ngn9qko9SkpKQmU/ySR9atWEeHp6Yl+/fph6tSpinjVqlXRoEED/Pzzz3la3uDBg7Fy5UrcunXL4nRLe8y+vr6Ij4+Hi4sLgNx/Gw4evwkA95jPTX6Fey3sE/uk8j4lJibC3d0der1e/qyjZ1ck95iDg4PNziWnpaXh4sWL6NevX56Xl9N3FwcHBzg4OJjFdToddDrlJs58M2WX+eZINyn36A1Cyqi42WRv9zhuHhNW45LFuAkSTJbiQoLJQi5GIcFoIW4t99z0Ket2y74NnxSXJMli3Np2z2s883nKbTwvuVuLs0/sE6DOPllbFj2bIjn4q02bNti2bRvi4+Pl2B9//IG0tDS0adMmT8tKSkrC+vXrUbdu3fxOk4iIyEyRLMwDBw6Em5sbOnTogM2bN+OXX37BBx98gN69eytGa/fv31/xjW/Pnj3o0KEDFixYgB07dmDx4sVo1KgRbt68ifHjx9uiK0RE9IIpksch3NzcsH37dnzwwQfo3LkzihUrhp49e+Krr75StDMajYrzJ2XLlkVaWhpGjx6N+Ph4ODs7o2HDhpg9ezbCw8MLuxtERPQCKpKDv2wtMTERrq6uTzUgosKo9QWU1fPl8tS2tk6BiHLwLJ91ZF2RPJRNRET0vGJhJiIiUhEWZiIiIhVhYSYiIlIRFmYiIiIVYWEmIiJSERZmIiIiFWFhJiIiUhEWZiIiIhVhYSYiIlIRFmYiIiIVYWEmIiJSERZmIiIiFWFhJiIiUhEWZiIiIhVhYSYiIlIRFmYiIiIVYWEmIiJSERZmIiIiFWFhJiLKpZkzZ8Lf3x+Ojo6oXbs29uzZ88T2u3btQu3ateHo6IiAgADMnj1bMX3BggWQJMnsLzU1VW6ze/dutG/fHt7e3pAkCatXrzZbz8SJExEUFARnZ2eULFkSzZs3x6FDh/Klz1T4WJiJiHJh2bJlGDp0KMaOHYsTJ06gUaNGaN26Na5cuWKxfWxsLNq0aYNGjRrhxIkTGDNmDIYMGYJVq1Yp2rm4uODGjRuKP0dHR3l6UlISqlevjh9++MFqbpUrV8YPP/yAf/75B3v37kWFChXQsmVL3LlzJ386T4VKEkIIWydR1CQmJsLV1RV6vR4uLi55mrfCqPUFlNXz5fLUtrZOgUihXr16qFWrFmbNmiXHgoOD0bFjR0yZMsWs/ciRI7F27VqcO3dOjr377rs4efIkDhw4ACBjj3no0KFISEjIVQ6SJOGPP/5Ax44dn9gu8zPor7/+QrNmzXK17KfxLJ91ZB33mImIcvDo0SMcO3YMLVu2VMRbtmyJ/fv3W5znwIEDZu1btWqFo0ePIj09XY49fPgQfn5+8PHxQbt27XDixIlnznXOnDlwdXVF9erVn2lZZBsszEREObh79y6MRiPKlCmjiJcpUwY3b960OM/NmzcttjcYDLh79y4AICgoCAsWLMDatWuxdOlSODo64qWXXkJMTEyec1y3bh2KFy8OR0dHfPvtt9i6dSs8PDzyvByyPZ2tEyAiel5IkqR4LIQwi+XUPmu8fv36qF+/vjz9pZdeQq1atfD9999jxowZecqtSZMm+Pvvv3H37l3873//Q7du3XDo0CF4enrmaTlke9xjJiLKgYeHB7Rardne8e3bt832ijN5eXlZbK/T6eDu7m5xHo1Gg7p16z7VHrOzszMqVqyI+vXrY+7cudDpdJg7d26el0O2x8JMRJQDe3t71K5dG1u3blXE
t27dioYNG1qcp0GDBmbtt2zZgjp16sDOzs7iPEII/P333yhbtuwz5yyEQFpa2jMvhwofD2UTEeXC8OHD8cYbb6BOnTpo0KAB5syZgytXruDdd98FAIwePRrXr1/HokWLAGSMwP7hhx8wfPhwvPPOOzhw4ADmzp2LpUuXysucNGkS6tevj0qVKiExMREzZszA33//jR9//FFu8/DhQ1y4cEF+HBsbi7///hulSpVC+fLlkZSUhC+++AKvvvoqypYti/j4eMycORPXrl1D165dC2nrUH5iYSYiyoXu3bsjPj4ekydPxo0bNxAaGooNGzbAz88PAHDjxg3Fb5r9/f2xYcMGDBs2DD/++CO8vb0xY8YMdOnSRW6TkJCAAQMG4ObNm3B1dUXNmjWxe/duhIeHy22OHj2KJk2ayI+HDx8OAOjTpw8WLFgArVaL6OhoLFy4EHfv3oW7uzvq1q2LPXv2oGrVqgW9WagA8HfMBYC/Y352/B0zkfrxd8wFg+eYiYiIVISFmYiISEVYmImIiFSEhZmIiEhFWJiJiIhUhIWZiIhIRViYiYiIVISFmYiISEVYmImIiFSEhZmIiEhFWJiJiIhUhIWZiIhIRViYiYiIVKTI3vbx33//xZAhQ7Bnzx44OzujZ8+emDp1KpycnHKcd+HChZgyZQouX76MihUrYsKECbyvKVFRNNHV1hmox0S9rTOg/69I7jEnJCSgadOmePDgAVatWoWoqCgsXrwY77zzTo7zrly5En379kWnTp2wceNGNGvWDN27d8eWLVsKIXMiInrRFck95p9++gn379/H33//DQ8PDwCATqdD7969MXbsWAQHB1udd9y4cejatSumTJkCAGjSpAmio6Mxfvx4tGzZslDyJyKiF1eR3GPesGEDmjdvLhdlAOjSpQscHBywYcMGq/PFxsYiOjoaPXv2VMR79eqFw4cP4+7duwWWMxEREVBEC/O5c+fM9oodHBwQGBiIc+fOPXE+AGbzhoSEQAiB6Ojo/E+WiIgoiyJ5KPv+/ftwc3Mzi5csWRL37t174nwAzOYtWbIkAFidNy0tDWlpafJjvV4vtzcYDAAAjUYDjUYDk8kEk8kkt82MG41GCCGgTU8CABhNgAkSdJKAJD1el8EECEiw0whFDhlxwC7bV610EyAB0JnFJUgQirgQgEFI0EBAaykuCWiz5GISgFFI0EoCmixxowBMwjz3vPTp3r170Gq1GfMZjYrcrcV1Oh2EEIq4JEnQarVm291aPLfPU05xrVYLSZLk5z+n3NknG/UpDZAgYICdMncYAAgYzeLpACQYs3106pAOkS0uQUALA0zQwASthbgWpiz7RhqYoIHRatwIHQSkLHEjNDCZxbUwPF2fsn2+5eZ5SkxMBADFNqVnVyQLM5DxAspOCGExntO8mS86a/NOmTIFkyZNMov7+/vnJlWywP3/bJ0B0QtmqvtTz/rgwQO4unKEe34pkoW5ZMmS8t5vVgkJCU8c+JW5Z3z//n2UKVNGMV/W6dmNHj0aw4cPlx+bTCbcu3cP7u7uufoioCaJiYnw9fXF1atX4eLiYut0iIq85/k9J4TAgwcP4O3tbetUipQiWZiDg4PNziWnpaXh4sWL6Nev3xPnAzLONQcFBcnxs2fPQpIkRSwrBwcHODg4KGKWDqU/T1xcXJ67Dwmi59nz+p7jnnL+K5KDv9q0aYNt27YhPj5ejv3xxx9IS0tDmzZtrM7n7++PoKAgLFu2TBFfunQpwsPDFaO8iYiICkKRLMwDBw6Em5sbOnTogM2bN+OXX37BBx98gN69eysOZffv3x86nfKgweTJk7F8+XKMHTsWO3fuxLBhw7BlyxZMnjy5sLtBREQvoCJ5KNvNzQ3bt2/HBx98gM6dO6NYsWLo2bMnvvrqK0U7o9FoNmK0a9euSE5OxpdffomoqChUrFgRy5Yte2EuLuLg4IAJEyaYHZonooLB9xxlJwmOcyciIlKNInkom4iI6HnFwkxE
RKQiLMxF3OLFixEeHg5XV1e4uLggODgYb7/9Nm7fvi23iYyMhCRJGDdunNn8kZGRaNeunfx4586dkCTJ4l/z5s0LpU9EajBx4kTF69/JyQlVq1bF9OnTFVfCOnjwIFq3bg0vLy84OTmhQoUKeO2113Do0CGzZTVu3NjieooXL66IWXsPZh/MSs8nPotF2NSpUzFmzBgMGzYMkydPhhACp0+fxuLFixEXFwdPT09F+xkzZmD48OFWL6SS1fz5881+183fM9KLxsnJCdu3bwcAJCcnY8uWLRg2bBh0Oh3ef/997N27F02aNMErr7yC2bNnw8XFBTExMVi9ejUOHz6MevXqKZa3Z88ebN++HU2bNs1x3R988AF69eqliD1vFzQiy1iYi7Dvv/8effv2xTfffCPHWrdujY8//lhxfWEAqFevHk6fPo3p06dbvLxodqGhoahTp06+50z0PNFoNKhfv778uGnTpjh8+DB+//13vP/++5g1axYqVKiA1atXy9cOb9q0KQYOHGj2HnR2dkZoaCgmTZqUq8Jcvnx5xbqp6OCh7CIsISEBZcuWtThNo1E+9R4eHhg0aBC+++47+RKkRJR3JUqUQHp6OoCM96Cnp6dclLPK/h4EgPHjx2P37t3YuXNnQadJKsbCXITVrl0bs2fPxs8//4ybN2/m2H7EiBFIT0/H9OnTc2xrNBphMBgUf9n3AIheBJmv/8TERKxcuRKbNm3Ca6+9BiDjPbh//36MGzcuV7eNbdOmDerWrYuJEyfm2NZkMpm9B7Nfl4GeTyzMRdjMmTNRqlQpvPPOOyhbtiwCAgLw4Ycf4vLlyxbbe3p64r333sN3330n37rSmvr168POzk7xx6uj0YsmKSlJfv27urqia9eu6NWrF4YMGQIA+Pjjj9GiRQt8/vnnCA4Ohru7O3r37o09e/ZYXeb48eOxa9cu7Nq164nrHjlypNl7sFmzZvnaP7INnmMuwkJDQ3HmzBn89ddf2LJlC3bt2oUZM2Zg/vz52L17N2rUqGE2z8cff4yZM2fiu+++w/jx460ue9GiRWZ36uIdZuhF4+TkhN27dwPIuFHOsWPHMH78eNjb2+Onn35CiRIlsGXLFhw+fBjr16/H3r17sWLFCixduhRz5szB22+/bbbMdu3aoVatWpg0aZI8sMySDz/8EK+//roiVqJEifztINkEC3MRZ29vjzZt2sg379i8eTPatm2LyZMn4/fffzdrX6ZMGQwcOBDTp0/H0KFDrS43ODiYg7/ohafRaBTvg5deegnp6ekYMWIEhgwZgqpVqwIAwsPDER4eDgCIjY1FREQEPvnkE4uFGcjYa+7YseMT96x9fHz4HiyieCj7BdOqVStUr17d7LaYWX3yySdISUnBjBkzCjEzoqIhJCQEAHD69GmL0/39/dG1a1fcv38ft27dstimQ4cOqFGjRq5+IUFFDwtzEWbpTZ+SkoKrV6/Cy8vL6nxly5bFgAED8O233yIxMbEgUyQqcjILsoeHh9XC+++//8LBweGJ920fP348tm3bhr179xZEmqRiPJRdhFWrVg3t27dHq1atULZsWcTFxeH777/H3bt38eGHHz5x3pEjR+Knn37CiRMnLJ47Pn36NAwGgyLm4OCAmjVr5msfiNTMZDLh4MGDAIBHjx7h2LFj+PzzzxESEoLGjRujS5cuMBgM6NKlCypVqoTExESsWrUK69atw9ChQ594R6mOHTsiLCwM27Ztg7Ozs9n0K1euyOvOqmbNmrxT1fNOUJH1448/ildeeUWUK1dO2NvbC29vb/HKK6+I7du3K9pFRESItm3bms3//vvvCwCKaTt27BAALP75+fkVdJeIVGPChAmK179OpxP+/v5i0KBB4tatW0IIITZt2iR69eolAgIChJOTk3B3dxfh4eFi7ty5wmAwKJbl7Oxsto6VK1cKAGbTrL0HAYjY2NgC7TcVPN72kYiISEV4jpmIiEhFWJiJiIhUhIWZiIhIRViY
iYiIVISFmYiISEVYmImIiFSEhZmIiEhFWJiJiIhUhIWZ6CnNmDEDkiQhNDTU4nRJknJ1w/uCEhkZicjISPlxcnIyJk6ciJ07d5q1nThxIiRJwt27dwsvQSKyiNfKJnpK8+bNAwCcOXMGhw4dQr169WyckdLMmTMVj5OTk+W7FWUt2ESkLtxjJnoKR48excmTJ9G2bVsAwNy5c22c0WPJyckAMm4/mHkLQiJ6frAwEz2FzEI8depUNGzYEL/99ptcEJ9k7969aNCgARwdHVGuXDmMGzcOP//8MyRJwuXLl+V2JpMJX3/9NYKCguDg4ABPT0+8+eabuHbtmmJ5kZGRCA0Nxe7du9GwYUMUK1YM/fr1k6dl7hlfvnwZpUuXBgBMmjQJkiRBkiT07dtXsbxbt26hZ8+ecHV1RZkyZdCvXz/o9XpFG0mS8P7772P+/PmoUqUKnJycUKdOHRw8eBBCCEybNg3+/v4oXrw4mjZtigsXLuRl0xK98FiYifIoJSUFS5cuRd26dREaGop+/frhwYMHWLFixRPnO3XqFFq0aIHk5GQsXLgQs2fPxvHjx/HFF1+YtX3vvfcwcuRItGjRAmvXrsVnn32GTZs2oWHDhmbngW/cuIHXX38dvXr1woYNGzBo0CCz5ZUtWxabNm0CAPTv3x8HDhzAgQMHMG7cOEW7Ll26oHLlyli1ahVGjRqFJUuWYNiwYWbLW7duHX7++WdMnToVS5cuxYMHD9C2bVt89NFH2LdvH3744QfMmTMHZ8+eRZcuXcB75RDlgW1vbkX0/Fm0aJEAIGbPni2EEOLBgweiePHiolGjRop2AMSECRPkx127dhXOzs7izp07csxoNIqQkBDF7frOnTsnAIhBgwYplnfo0CEBQIwZM0aORURECABi27ZtZnlGRESIiIgI+fGdO3fMcsqUeQvDr7/+WhEfNGiQcHR0FCaTSdEvLy8v8fDhQzm2evVqAUDUqFFD0Xb69OkCgDh16pTZOonIMu4xE+XR3Llz4eTkhB49egAAihcvjq5du2LPnj2IiYmxOt+uXbvQtGlTeHh4yDGNRoNu3bop2u3YsQMAzA4zh4eHIzg4GNu2bVPES5YsiaZNmz5Ll2Svvvqq4nFYWBhSU1Nx+/ZtRbxJkyZwdnaWHwcHBwMAWrduDUmSzOL//fdfvuRH9CJgYSbKgwsXLmD37t1o27YthBBISEhAQkICXnvtNQCPR2pbEh8fjzJlypjFs8fi4+MBZBx+zs7b21uenslSu6fl7u6ueOzg4AAg4/B9VqVKlVI8tre3f2I8NTU133IkKupYmInyYN68eRBCYOXKlShZsqT8lzk6e+HChTAajRbndXd3x61bt8ziN2/eNGsHZJw7zi4uLk6xxw1AsYdKRM8/FmaiXDIajVi4cCECAwOxY8cOs7+PPvoIN27cwMaNGy3OHxERge3btysGb5lMJrNBY5mHpX/99VdF/MiRIzh37hyaNWv2VPlb2/slInXhBUaIcmnjxo2Ii4vDV199ZfECHaGhofjhhx8wd+5ctGvXzmz62LFj8eeff6JZs2YYO3YsnJycMHv2bCQlJQHION8MAFWqVMGAAQPw/fffQ6PRoHXr1rh8+TLGjRsHX19fi6Okc6NEiRLw8/PDmjVr0KxZM5QqVQoeHh6oUKHCUy2PiAoG95iJcmnu3Lmwt7fHW2+9ZXG6h4cHOnXqhHXr1lk8ZF29enVs3boVTk5OePPNNzFgwABUrVpV/nmTq6ur3HbWrFmYOnUqNmzYgHbt2mHs2LFo2bIl9u/fb3YeOK99KFasGF599VXUrVvXppcMJSLLJCH4A0MiW2rZsiUuX76Mf//919apEJEK8FA2USEaPnw4atasCV9fX9y7dw+LFy/G1q1bVXVJTyKyLRZmokJkNBoxfvx43Lx5E5IkISQkBL/88gtef/11W6dGRCrBQ9lEREQqwsFfREREKsLCTEREpCIszERERCrCwkxERKQiLMxEREQq
wsJMRESkIizMREREKsLCTEREpCIszERERCry/wDjBRlFgyhl8QAAAABJRU5ErkJggg==",
      "text/plain": [
       "<Figure size 400x500 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "<Figure size 640x480 with 0 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import matplotlib.pyplot as plt\n",
    "\n",
    "\n",
    "labels = ['SNE', 'BSNE']\n",
    "times = [3.4191, 0.0513]\n",
    "x = range(len(labels))\n",
    "\n",
    "\n",
    "plt.figure(figsize=(4, 5))\n",
    "bars = plt.bar(x, times, color=['#1f77b4', '#ff7f0e'], width=0.3)\n",
    "\n",
    "for bar in bars:\n",
    "    height = bar.get_height()\n",
    "    plt.text(bar.get_x() + bar.get_width()/2., height + 0.1,\n",
    "             f'{height:.4f}', ha='center', va='bottom')\n",
    "\n",
    "\n",
    "plt.title('SNE vs BSNE Pretraining Time Comparison', fontsize=14, fontweight='bold')\n",
    "plt.xlabel('Algorithm', fontsize=12)\n",
    "plt.ylabel('Average Epoch Time (seconds)', fontsize=12)\n",
    "\n",
    "plt.xticks(x, labels, fontsize=11)\n",
    "plt.yticks(fontsize=11)\n",
    "\n",
    "plt.grid(axis='y', linestyle='--', alpha=0.7)\n",
    "\n",
    "plt.tight_layout()\n",
    "\n",
    "plt.savefig('sne_vs_bsne_time_comparison.png', dpi=300, bbox_inches='tight')\n",
    "\n",
    "plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ec9b34a3",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "tsne01",
   "language": "python",
   "name": "tsne01"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
